diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 65f9a4dcf2b..3c39166c1e8 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -151,7 +151,7 @@ jobs: - name: Install dependencies run: | find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; - python -m pip install --user --upgrade pip setuptools wheel twine + python -m pip install --user --upgrade pip setuptools wheel twine packaging # install the latest pytorch for testing # however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated # fresh torch installation according to pyproject.toml diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 60b610565ea..a014a4ed1db 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -24,7 +24,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install setuptools run: | - python -m pip install --user --upgrade setuptools wheel + python -m pip install --user --upgrade setuptools wheel packaging - name: Build and test source archive and wheel file run: | find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; @@ -104,7 +104,7 @@ jobs: run: | find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; git describe - python -m pip install --user --upgrade setuptools wheel + python -m pip install --user --upgrade setuptools wheel packaging python setup.py build cat build/lib/monai/_version.py - name: Upload version diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml index e94e1dac5a9..8d8cccffad1 100644 --- a/.github/workflows/weekly-preview.yml +++ b/.github/workflows/weekly-preview.yml @@ -5,6 +5,39 @@ on: - cron: "0 2 * * 0" # 02:00 of every Sunday jobs: + flake8-py3: + runs-on: ubuntu-latest + strategy: + matrix: + opt: ["codeformat", "pytype", "mypy"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.9 + uses: actions/setup-python@v5 + with: + python-version: '3.9' + - name: cache weekly timestamp + id: pip-cache + run: | + echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT + - name: cache for pip + uses: actions/cache@v4 + id: cache + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} + - name: Install dependencies + run: | + find /opt/hostedtoolcache/* -maxdepth 0 ! 
-name 'Python' -exec rm -rf {} \; + python -m pip install --upgrade pip wheel + python -m pip install -r requirements-dev.txt + - name: Lint and type check + run: | + # clean up temporary files + $(pwd)/runtests.sh --build --clean + # Github actions have 2 cores, so parallelize pytype + $(pwd)/runtests.sh --build --${{ matrix.opt }} -j 2 + packaging: if: github.repository == 'Project-MONAI/MONAI' runs-on: ubuntu-latest @@ -19,7 +52,7 @@ jobs: python-version: '3.9' - name: Install setuptools run: | - python -m pip install --user --upgrade setuptools wheel + python -m pip install --user --upgrade setuptools wheel packaging - name: Build distribution run: | export HEAD_COMMIT_ID=$(git rev-parse HEAD) diff --git a/README.md b/README.md index 5345cdb9264..498d3c61499 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ [![postmerge](https://img.shields.io/github/checks-status/project-monai/monai/dev?label=postmerge)](https://github.com/Project-MONAI/MONAI/actions?query=branch%3Adev) [![Documentation Status](https://readthedocs.org/projects/monai/badge/?version=latest)](https://docs.monai.io/en/latest/) [![codecov](https://codecov.io/gh/Project-MONAI/MONAI/branch/dev/graph/badge.svg?token=6FTC7U1JJ4)](https://codecov.io/gh/Project-MONAI/MONAI) +[![monai Downloads Last Month](https://assets.piptrends.com/get-last-month-downloads-badge/monai.svg 'monai Downloads Last Month by pip Trends')](https://piptrends.com/package/monai) MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/). Its ambitions are: diff --git a/docs/requirements.txt b/docs/requirements.txt index fe415a07b57..ff94f7b6dea 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -41,3 +41,4 @@ onnxruntime; python_version <= '3.10' zarr huggingface_hub pyamg>=5.0.0 +packaging diff --git a/docs/source/apps.rst b/docs/source/apps.rst index c6ba8c0b9a2..cc4cea8c1e8 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -248,6 +248,22 @@ FastMRIReader ~~~~~~~~~~~~~ .. autofunction:: monai.apps.reconstruction.complex_utils.complex_conj +`Vista3d` +--------- +.. automodule:: monai.apps.vista3d.inferer +.. autofunction:: point_based_window_inferer + +.. automodule:: monai.apps.vista3d.transforms +.. autoclass:: VistaPreTransformd + :members: +.. autoclass:: VistaPostTransformd + :members: +.. autoclass:: Relabeld + :members: + +.. automodule:: monai.apps.vista3d.sampler +.. autofunction:: sample_prompt_pairs + `Auto3DSeg` ----------- .. automodule:: monai.apps.auto3dseg @@ -261,11 +277,3 @@ FastMRIReader .. autoclass:: monai.apps.nnunet.nnUNetV2Runner :members: - -`Generative AI` ---------------- - -`MAISI Utilities` -~~~~~~~~~~~~~~~~~ -.. 
automodule:: monai.apps.generation.maisi.utils.morphological_ops - :members: diff --git a/docs/source/installation.md b/docs/source/installation.md index 4308a076471..70a8b6f1d4e 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub] +[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub, segment-anything] ``` which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub`, `pyamg` and `segment-anything` respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 249375dfc17..1810fec49ba 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -481,6 +481,11 @@ Nets .. autoclass:: SegResNetDS :members: +`SegResNetDS2` +~~~~~~~~~~~~~~ +.. autoclass:: SegResNetDS2 + :members: + `SegResNetVAE` ~~~~~~~~~~~~~~ .. autoclass:: SegResNetVAE @@ -556,6 +561,11 @@ Nets .. autoclass:: UNETR :members: +`VISTA3D` +~~~~~~~~~ +.. autoclass:: VISTA3D + :members: + `SwinUNETR` ~~~~~~~~~~~ .. autoclass:: SwinUNETR diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index a3598216791..637f0873f14 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -2310,6 +2310,9 @@ Utilities .. automodule:: monai.transforms.utils_pytorch_numpy_unification :members: +.. automodule:: monai.transforms.utils_morphological_ops + :members: + By Categories ------------- .. 
toctree:: diff --git a/environment-dev.yml b/environment-dev.yml index d23958baba6..a4651ec7e41 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -5,10 +5,10 @@ channels: - nvidia - conda-forge dependencies: - - numpy>=1.20 + - numpy>=1.24,<2.0 - pytorch>=1.9 - torchvision - - pytorch-cuda=11.6 + - pytorch-cuda>=11.6 - pip - pip: - -r requirements-dev.txt diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index f27f73ec60e..a52274b24a3 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -13,25 +13,17 @@ import gc import logging -from typing import TYPE_CHECKING, Sequence, cast +from typing import Sequence import torch import torch.nn as nn import torch.nn.functional as F from monai.networks.blocks import Convolution -from monai.utils import optional_import +from monai.networks.blocks.spatialattention import SpatialAttentionBlock +from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL from monai.utils.type_conversion import convert_to_tensor -AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock") -AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL") -ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock") - -if TYPE_CHECKING: - from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType -else: - AutoencoderKLType = cast(type, AutoencoderKL) - # Set up logging configuration logger = logging.getLogger(__name__) @@ -518,11 +510,13 @@ class MaisiEncoder(nn.Module): in_channels: Number of input channels. num_channels: Sequence of block output channels. out_channels: Number of channels in the bottom layer (latent space) of the autoencoder. - num_res_blocks: Number of residual blocks (see ResBlock) per level. + num_res_blocks: Number of residual blocks (see AEKLResBlock) per level. norm_num_groups: Number of groups for the group norm layers. norm_eps: Epsilon for the normalization. attention_levels: Indicate which level from num_channels contain an attention block. with_nonlocal_attn: If True, use non-local attention block. + include_fc: whether to include the final linear layer in the attention block. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. 
@@ -547,6 +541,8 @@ def __init__( print_info: bool = False, save_mem: bool = True, with_nonlocal_attn: bool = True, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -603,11 +599,13 @@ def __init__( input_channel = output_channel if attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=input_channel, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -626,7 +624,7 @@ def __init__( if with_nonlocal_attn: blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=num_channels[-1], norm_num_groups=norm_num_groups, @@ -636,16 +634,18 @@ def __init__( ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=num_channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=num_channels[-1], norm_num_groups=norm_num_groups, @@ -699,11 +699,13 @@ class MaisiDecoder(nn.Module): num_channels: Sequence of block output channels. in_channels: Number of channels in the bottom layer (latent space) of the autoencoder. out_channels: Number of output channels. - num_res_blocks: Number of residual blocks (see ResBlock) per level. + num_res_blocks: Number of residual blocks (see AEKLResBlock) per level. norm_num_groups: Number of groups for the group norm layers. norm_eps: Epsilon for the normalization. attention_levels: Indicate which level from num_channels contain an attention block. with_nonlocal_attn: If True, use non-local attention block. + include_fc: whether to include the final linear layer in the attention block. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. 
@@ -729,6 +731,8 @@ def __init__( print_info: bool = False, save_mem: bool = True, with_nonlocal_attn: bool = True, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, use_convtranspose: bool = False, ) -> None: @@ -758,7 +762,7 @@ def __init__( if with_nonlocal_attn: blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -767,16 +771,18 @@ def __init__( ) ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -812,11 +818,13 @@ def __init__( if reversed_attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -870,7 +878,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class AutoencoderKlMaisi(AutoencoderKLType): +class AutoencoderKlMaisi(AutoencoderKL): """ AutoencoderKL with custom MaisiEncoder and MaisiDecoder. @@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType): norm_eps: Epsilon for the normalization. with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder. with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder. + include_fc: whether to include the final linear layer. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. use_checkpointing: If True, use activation checkpointing. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. 
@@ -909,6 +919,8 @@ def __init__( norm_eps: float = 1e-6, with_encoder_nonlocal_attn: bool = False, with_decoder_nonlocal_attn: bool = False, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, use_checkpointing: bool = False, use_convtranspose: bool = False, @@ -930,12 +942,14 @@ def __init__( norm_eps, with_encoder_nonlocal_attn, with_decoder_nonlocal_attn, - use_flash_attention, use_checkpointing, use_convtranspose, + include_fc, + use_combined_linear, + use_flash_attention, ) - self.encoder = MaisiEncoder( + self.encoder: nn.Module = MaisiEncoder( spatial_dims=spatial_dims, in_channels=in_channels, num_channels=num_channels, @@ -945,6 +959,8 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, num_splits=num_splits, dim_split=dim_split, @@ -953,7 +969,7 @@ def __init__( save_mem=save_mem, ) - self.decoder = MaisiDecoder( + self.decoder: nn.Module = MaisiDecoder( spatial_dims=spatial_dims, num_channels=num_channels, in_channels=latent_channels, @@ -963,6 +979,8 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, use_convtranspose=use_convtranspose, num_splits=num_splits, diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 3641124b7da..269086d971e 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -11,24 +11,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, cast +from typing import Sequence import torch -from monai.utils import optional_import +from monai.networks.nets.controlnet import ControlNet +from monai.networks.nets.diffusion_model_unet import get_timestep_embedding -ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet") -get_timestep_embedding, has_get_timestep_embedding = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" -) -if TYPE_CHECKING: - from generative.networks.nets.controlnet import ControlNet as ControlNetType -else: - ControlNetType = cast(type, ControlNet) - - -class ControlNetMaisi(ControlNetType): +class ControlNetMaisi(ControlNet): """ Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image Diffusion Models" (https://arxiv.org/abs/2302.05543) @@ -49,10 +40,12 @@ class ControlNetMaisi(ControlNetType): num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. conditioning_embedding_in_channels: number of input channels for the conditioning embedding. conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. use_checkpointing: if True, use activation checkpointing to save memory. + include_fc: whether to include the final linear layer. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. 
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -71,10 +64,12 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, conditioning_embedding_in_channels: int = 1, - conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), + conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), use_checkpointing: bool = True, + include_fc: bool = False, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__( spatial_dims, @@ -91,9 +86,11 @@ def __init__( cross_attention_dim, num_class_embeds, upcast_attention, - use_flash_attention, conditioning_embedding_in_channels, conditioning_embedding_num_channels, + include_fc, + use_combined_linear, + use_flash_attention, ) self.use_checkpointing = use_checkpointing @@ -105,7 +102,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ) -> tuple[Sequence[torch.Tensor], torch.Tensor]: + ) -> tuple[list[torch.Tensor], torch.Tensor]: emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) h = self._apply_initial_convolution(x) if self.use_checkpointing: diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index d5f5f6136ba..e990b5fc985 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -37,21 +37,15 @@ from torch import nn from monai.networks.blocks import Convolution -from monai.utils import ensure_tuple_rep, optional_import -from monai.utils.type_conversion import convert_to_tensor - -get_down_block, has_get_down_block = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_down_block" -) -get_mid_block, has_get_mid_block = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_mid_block" -) -get_timestep_embedding, has_get_timestep_embedding = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" +from monai.networks.nets.diffusion_model_unet import ( + get_down_block, + get_mid_block, + get_timestep_embedding, + get_up_block, + zero_module, ) -get_up_block, has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block") -xformers, has_xformers = optional_import("xformers") -zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module") +from monai.utils import ensure_tuple_rep +from monai.utils.type_conversion import convert_to_tensor __all__ = ["DiffusionModelUNetMaisi"] @@ -78,6 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module): cross_attention_dim: Number of context dimensions to use. num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: If True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. 
dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers. include_top_region_index_input: If True, use top region index input. @@ -102,6 +98,8 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, dropout_cattn: float = 0.0, include_top_region_index_input: bool = False, @@ -152,9 +150,6 @@ def __init__( "`num_channels`." ) - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - if use_flash_attention is True and not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." @@ -210,7 +205,6 @@ def __init__( input_channel = output_channel output_channel = num_channels[i] is_final_block = i == len(num_channels) - 1 - down_block = get_down_block( spatial_dims=spatial_dims, in_channels=input_channel, @@ -227,6 +221,8 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -245,6 +241,8 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -280,6 +278,8 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) diff --git a/monai/apps/generation/maisi/utils/__init__.py b/monai/apps/vista3d/__init__.py similarity index 100% rename from monai/apps/generation/maisi/utils/__init__.py rename to monai/apps/vista3d/__init__.py diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py new file mode 100644 index 00000000000..709f81f6243 --- /dev/null +++ b/monai/apps/vista3d/inferer.py @@ -0,0 +1,177 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import copy +from collections.abc import Sequence +from typing import Any + +import torch + +from monai.data.meta_tensor import MetaTensor +from monai.utils import optional_import + +tqdm, _ = optional_import("tqdm", name="tqdm") + +__all__ = ["point_based_window_inferer"] + + +def point_based_window_inferer( + inputs: torch.Tensor | MetaTensor, + roi_size: Sequence[int], + predictor: torch.nn.Module, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + class_vector: torch.Tensor | None = None, + prompt_class: torch.Tensor | None = None, + prev_mask: torch.Tensor | MetaTensor | None = None, + point_start: int = 0, + center_only: bool = True, + margin: int = 5, + **kwargs: Any, +) -> torch.Tensor: + """ + Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image. + The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by + patch inference and average output stitching, and finally returns the segmented mask. + + Args: + inputs: [1CHWD], input image to be processed. + roi_size: the spatial window size for inferences. + When its components have None or non-positives, the corresponding inputs dimension will be used. + if the components of the `roi_size` are non-positive values, the transform will use the + corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + sw_batch_size: the batch size to run window slices. + predictor: the model. For vista3D, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D]. + Add transpose=True in kwargs for vista3d. + point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points. + point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes. + 2/3 means negative/positive points for special supported classes (e.g. tumor, vessel). + class_vector: [B]. Used for class-head automatic segmentation. Can be None value. + prompt_class: [B]. The same as class_vector representing the point class and inform point head about + supported class or zeroshot, not used for automatic segmentation. If None, point head is default + to supported class segmentation. + prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks. + point_start: only use points starting from this number. All points before this number is used to generate + prev_mask. This is used to avoid re-calculating the points in previous iterations if given prev_mask. + center_only: for each point, only crop the patch centered at this point. If false, crop 3 patches for each point. + margin: if center_only is false, this value is the distance between point to the patch boundary. + Returns: + stitched_output: [1, B, H, W, D]. The value is before sigmoid. + Notice: The function only supports SINGLE OBJECT INFERENCE with B=1. 
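+
+    Example (an illustrative sketch only; ``net`` is a placeholder for a predictor that follows the
+    VISTA3D calling convention described above, and the shapes are arbitrary):
+
+    .. code-block:: python
+
+        image = torch.rand(1, 1, 128, 128, 128)  # [1, C, H, W, D]
+        point_coords = torch.tensor([[[64, 64, 64], [72, 64, 64]]])  # [B=1, N=2, 3]
+        point_labels = torch.tensor([[1, 1]])  # two positive clicks
+        logits = point_based_window_inferer(
+            inputs=image,
+            roi_size=(96, 96, 96),
+            predictor=net,
+            point_coords=point_coords,
+            point_labels=point_labels,
+            transpose=True,  # forwarded via kwargs, see the note on vista3d above
+        )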
+ """ + if not point_coords.shape[0] == 1: + raise ValueError("Only supports single object point click.") + if not len(inputs.shape) == 5: + raise ValueError("Input image should be 5D.") + image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size) + point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device) + prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None + stitched_output = None + for p in point_coords[0][point_start:]: + lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin) + ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin) + lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin) + for i in range(len(lx_)): + for j in range(len(ly_)): + for k in range(len(lz_)): + lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k]) + unravel_slice = [ + slice(None), + slice(None), + slice(int(lx), int(rx)), + slice(int(ly), int(ry)), + slice(int(lz), int(rz)), + ] + batch_image = image[unravel_slice] + output = predictor( + batch_image, + point_coords=point_coords, + point_labels=point_labels, + class_vector=class_vector, + prompt_class=prompt_class, + patch_coords=unravel_slice, + prev_mask=prev_mask, + **kwargs, + ) + if stitched_output is None: + stitched_output = torch.zeros( + [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu" + ) + stitched_mask = torch.zeros( + [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu" + ) + stitched_output[unravel_slice] += output.to("cpu") + stitched_mask[unravel_slice] = 1 + # if stitched_mask is 0, then NaN value + stitched_output = stitched_output / stitched_mask + # revert padding + stitched_output = stitched_output[ + :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1] + ] + stitched_mask = stitched_mask[ + :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1] + ] + if prev_mask is not None: + prev_mask = prev_mask[ + :, + :, + pad[4] : image.shape[-3] - pad[5], + pad[2] : image.shape[-2] - pad[3], + pad[0] : image.shape[-1] - pad[1], + ] + prev_mask = prev_mask.to("cpu") # type: ignore + # for un-calculated place, use previous mask + stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1] + if isinstance(inputs, torch.Tensor): + inputs = MetaTensor(inputs) + if not hasattr(stitched_output, "meta"): + stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta) + return stitched_output + + +def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]: + """Helper function to get the window index.""" + if p - roi // 2 < 0: + left, right = 0, roi + elif p + roi // 2 > s: + left, right = s - roi, s + else: + left, right = int(p) - roi // 2, int(p) + roi // 2 + return left, right + + +def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]: + """Get the window index.""" + left, right = _get_window_idx_c(p, roi, s) + if center_only: + return [left], [right] + left_most = max(0, p - roi + margin) + right_most = min(s, p + roi - margin) + left_list = [left_most, right_most - roi, left] + right_list = [left_most + roi, right_most, right] + return left_list, right_list + + +def _pad_previous_mask( + inputs: 
torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0 +) -> tuple[torch.Tensor | MetaTensor, list[int]]: + """Helper function to pad inputs.""" + pad_size = [] + for k in range(len(inputs.shape) - 1, 1, -1): + diff = max(roi_size[k - 2] - inputs.shape[k], 0) + half = diff // 2 + pad_size.extend([half, diff - half]) + if any(pad_size): + inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue) # type: ignore + return inputs, pad_size diff --git a/monai/apps/vista3d/sampler.py b/monai/apps/vista3d/sampler.py new file mode 100644 index 00000000000..b7aeb89a2ed --- /dev/null +++ b/monai/apps/vista3d/sampler.py @@ -0,0 +1,172 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +import random +from collections.abc import Callable, Sequence +from typing import Any + +import numpy as np +import torch +from torch import Tensor + +__all__ = ["sample_prompt_pairs"] + +ENABLE_SPECIAL = True +SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128) +MERGE_LIST = { + 1: [25, 26], # hepatic tumor and vessel merge into liver + 4: [24], # pancreatic tumor merge into pancreas + 132: [57], # overlap with trachea merge into airway +} + + +def _get_point_label(id: int) -> tuple[int, int]: + if id in SPECIAL_INDEX and ENABLE_SPECIAL: + return 2, 3 + else: + return 0, 1 + + +def sample_prompt_pairs( + labels: Tensor, + label_set: Sequence[int], + max_prompt: int | None = None, + max_foreprompt: int | None = None, + max_backprompt: int = 1, + max_point: int = 20, + include_background: bool = False, + drop_label_prob: float = 0.2, + drop_point_prob: float = 0.2, + point_sampler: Callable | None = None, + **point_sampler_kwargs: Any, +) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]: + """ + Sample training pairs for VISTA3D training. + + Args: + labels: [1, 1, H, W, D], ground truth labels. + label_set: the label list for the specific dataset. Note if 0 is included in label_set, + it will be added into automatic branch training. Recommend removing 0 from label_set + for multi-partially-labeled-dataset training, and adding 0 for finetuning specific dataset. + The reason is region with 0 in one partially labeled dataset may contain foregrounds in + another dataset. + max_prompt: int, max number of total prompt, including foreground and background. + max_foreprompt: int, max number of prompt from foreground. + max_backprompt: int, max number of prompt from background. + max_point: maximum number of points for each object. + include_background: if include 0 into training prompt. If included, background 0 is treated + the same as foreground. Always be False for multi-partial-dataset training. If needed, + can be true for finetuning specific dataset, . + drop_label_prob: probability to drop label prompt. + drop_point_prob: probability to drop point prompt. + point_sampler: sampler to augment masks with supervoxel. + point_sampler_kwargs: arguments for point_sampler. 
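+
+    Example (a minimal sketch; the label values and spatial size below are arbitrary):
+
+    .. code-block:: python
+
+        labels = torch.randint(0, 3, (1, 1, 96, 96, 96))  # [1, 1, H, W, D] ground truth
+        label_prompt, point, point_label, prompt_class = sample_prompt_pairs(
+            labels, label_set=[1, 2], max_point=5
+        )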
+ + Returns: + label_prompt: [B, 1]. The classes used for training automatic segmentation. + point: [B, N, 3]. The corresponding points for each class. + Note that background label prompt requires matching point as well ([0,0,0] is used). + point_label: [B, N]. The corresponding point labels for each point (negative or positive). + -1 is used for padding the background label prompt and will be ignored. + prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss. + label_prompt can be None, and prompt_class is used to identify point classes. + """ + # class label number + if not labels.shape[0] == 1: + raise ValueError("only support batch size 1") + labels = labels[0, 0] + device = labels.device + unique_labels = labels.unique().cpu().numpy().tolist() + if include_background: + unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set))) + else: + unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)) - {0}) + background_labels = list(set(label_set) - set(unique_labels)) + # during training, balance background and foreground prompts + if max_backprompt is not None: + if len(background_labels) > max_backprompt: + random.shuffle(background_labels) + background_labels = background_labels[:max_backprompt] + + if max_foreprompt is not None: + if len(unique_labels) > max_foreprompt: + random.shuffle(unique_labels) + unique_labels = unique_labels[:max_foreprompt] + + if max_prompt is not None: + if len(unique_labels) + len(background_labels) > max_prompt: + if len(unique_labels) > max_prompt: + unique_labels = random.sample(unique_labels, max_prompt) + background_labels = [] + else: + background_labels = random.sample(background_labels, max_prompt - len(unique_labels)) + _point = [] + _point_label = [] + # if use regular sampling + if point_sampler is None: + num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1) + num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2)))) + for id in unique_labels: + neg_id, pos_id = _get_point_label(id) + plabels = labels == int(id) + nlabels = ~plabels + plabelpoints = torch.nonzero(plabels) + nlabelpoints = torch.nonzero(nlabels) + # final sampled positive points + num_pa = min(len(plabelpoints), num_p) + # final sampled negative points + num_na = min(len(nlabelpoints), num_n) + _point.append( + torch.stack( + random.choices(plabelpoints, k=num_pa) + + random.choices(nlabelpoints, k=num_na) + + [torch.tensor([0, 0, 0], device=device)] * (num_p + num_n - num_pa - num_na) + ) + ) + _point_label.append( + torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na)).to( + device + ) + ) + for _ in background_labels: + # pad the background labels + _point.append(torch.zeros(num_p + num_n, 3).to(device)) # all 0 + _point_label.append(torch.zeros(num_p + num_n).to(device) - 1) # -1 not a point + else: + _point, _point_label = point_sampler(unique_labels, **point_sampler_kwargs) + for _ in background_labels: + # pad the background labels + _point.append(torch.zeros(len(_point_label[0]), 3).to(device)) # all 0 + _point_label.append(torch.zeros(len(_point_label[0])).to(device) - 1) # -1 not a point + if len(unique_labels) == 0 and len(background_labels) == 0: + # if max_backprompt is 0 and len(unique_labels), there is no effective prompt and the iteration must + # be skipped. Handle this in trainer. 
+ label_prompt, point, point_label, prompt_class = None, None, None, None + else: + label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long() + point = torch.stack(_point) + point_label = torch.stack(_point_label) + prompt_class = copy.deepcopy(label_prompt) + if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0: + label_prompt = None + # If label prompt is dropped, there is no need to pad with points with label -1. + pad = len(background_labels) + point = point[: len(point) - pad] # type: ignore + point_label = point_label[: len(point_label) - pad] + prompt_class = prompt_class[: len(prompt_class) - pad] + else: + if random.uniform(0, 1) < drop_point_prob: + point = None + point_label = None + return label_prompt, point, point_label, prompt_class diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py new file mode 100644 index 00000000000..3e8145cd80b --- /dev/null +++ b/monai/apps/vista3d/transforms.py @@ -0,0 +1,224 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from typing import Sequence + +import numpy as np +import torch + +from monai.config import DtypeLike, KeysCollection +from monai.transforms import MapLabelValue +from monai.transforms.transform import MapTransform +from monai.transforms.utils import keep_components_with_positive_points +from monai.utils import look_up_option + +__all__ = ["VistaPreTransformd", "VistaPostTransformd", "Relabeld"] + + +def _get_name_to_index_mapping(labels_dict: dict | None) -> dict: + """get the label name to index mapping""" + name_to_index_mapping = {} + if labels_dict is not None: + name_to_index_mapping = {v.lower(): int(k) for k, v in labels_dict.items()} + return name_to_index_mapping + + +def _convert_name_to_index(name_to_index_mapping: dict, label_prompt: list | None) -> list | None: + """convert the label name to index""" + if label_prompt is not None and isinstance(label_prompt, list): + converted_label_prompt = [] + # for new class, add to the mapping + for l in label_prompt: + if isinstance(l, str) and not l.isdigit(): + if l.lower() not in name_to_index_mapping: + name_to_index_mapping[l.lower()] = len(name_to_index_mapping) + for l in label_prompt: + if isinstance(l, (int, str)): + converted_label_prompt.append( + name_to_index_mapping.get(l.lower(), int(l) if l.isdigit() else 0) if isinstance(l, str) else int(l) + ) + else: + converted_label_prompt.append(l) + return converted_label_prompt + return label_prompt + + +class VistaPreTransformd(MapTransform): + def __init__( + self, + keys: KeysCollection, + allow_missing_keys: bool = False, + special_index: Sequence[int] = (25, 26, 27, 28, 29, 117), + labels_dict: dict | None = None, + subclass: dict | None = None, + ) -> None: + """ + Pre-transform for Vista3d. + + It performs two functionalities: + + 1. If label prompt shows the points belong to special class (defined by special index, e.g. 
tumors, vessels), + convert point labels from 0 (negative), 1 (positive) to special 2 (negative), 3 (positive). + + 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. + e.g. "lung" label is converted to ["left lung", "right lung"]. + + The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B, + where each element is an int value of length [B, N]. + + Args: + keys: keys of the corresponding items to be transformed. + special_index: the index that defines the special class. + subclass: a dictionary that maps a label prompt to its subclasses. + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.special_index = special_index + self.subclass = subclass + self.name_to_index_mapping = _get_name_to_index_mapping(labels_dict) + + def __call__(self, data): + label_prompt = data.get("label_prompt", None) + point_labels = data.get("point_labels", None) + # convert the label name to index if needed + label_prompt = _convert_name_to_index(self.name_to_index_mapping, label_prompt) + try: + # The evaluator will check prompt. The invalid prompt will be skipped here and captured by evaluator. + if self.subclass is not None and label_prompt is not None: + _label_prompt = [] + subclass_keys = list(map(int, self.subclass.keys())) + for i in range(len(label_prompt)): + if label_prompt[i] in subclass_keys: + _label_prompt.extend(self.subclass[str(label_prompt[i])]) + else: + _label_prompt.append(label_prompt[i]) + data["label_prompt"] = _label_prompt + if label_prompt is not None and point_labels is not None: + if label_prompt[0] in self.special_index: + point_labels = np.array(point_labels) + point_labels[point_labels == 0] = 2 + point_labels[point_labels == 1] = 3 + point_labels = point_labels.tolist() + data["point_labels"] = point_labels + except Exception: + # There is specific requirements for `label_prompt` and `point_labels`. + # If B > 1 or `label_prompt` is in subclass_keys, `point_labels` must be None. + # Those formatting errors should be captured later. + warnings.warn("VistaPreTransformd failed to transform label prompt or point labels.") + + return data + + +class VistaPostTransformd(MapTransform): + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + """ + Post-transform for Vista3d. It converts the model output logits into final segmentation masks. + If `label_prompt` is None, the output will be thresholded to be sequential indexes [0,1,2,...], + else the indexes will be [0, label_prompt[0], label_prompt[1], ...]. + If `label_prompt` is None while `points` are provided, the model will perform postprocess to remove + regions that does not contain positive points. + + Args: + keys: keys of the corresponding items to be transformed. + dataset_transforms: a dictionary specifies the transform for corresponding dataset: + key: dataset name, value: list of data transforms. + dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". + allow_missing_keys: don't raise exception if key is missing. 
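+
+        Example (an illustrative sketch; the key name ``"pred"`` and the tensors are assumptions):
+
+        .. code-block:: python
+
+            post = VistaPostTransformd(keys="pred")
+            data = {"pred": logits, "label_prompt": label_prompt}  # raw model logits
+            data = post(data)  # data["pred"] now holds the final segmentation mask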
+ + """ + super().__init__(keys, allow_missing_keys) + + def __call__(self, data): + """data["label_prompt"] should not contain 0""" + for keys in self.keys: + if keys in data: + pred = data[keys] + object_num = pred.shape[0] + device = pred.device + if data.get("label_prompt", None) is None and data.get("points", None) is not None: + pred = keep_components_with_positive_points( + pred.unsqueeze(0), + point_coords=data.get("points").to(device), + point_labels=data.get("point_labels").to(device), + )[0] + pred[pred < 0] = 0.0 + # if it's multichannel, perform argmax + if object_num > 1: + # concate background channel. Make sure user did not provide 0 as prompt. + is_bk = torch.all(pred <= 0, dim=0, keepdim=True) + pred = pred.argmax(0).unsqueeze(0).float() + 1.0 + pred[is_bk] = 0.0 + else: + # AsDiscrete will remove NaN + # pred = monai.transforms.AsDiscrete(threshold=0.5)(pred) + pred[pred > 0] = 1.0 + if "label_prompt" in data and data["label_prompt"] is not None: + pred += 0.5 # inplace mapping to avoid cloning pred + label_prompt = data["label_prompt"].to(device) # Ensure label_prompt is on the same device + for i in range(1, object_num + 1): + frac = i + 0.5 + pred[pred == frac] = label_prompt[i - 1].to(pred.dtype) + pred[pred == 0.5] = 0.0 + data[keys] = pred + return data + + +class Relabeld(MapTransform): + def __init__( + self, + keys: KeysCollection, + label_mappings: dict[str, list[tuple[int, int]]], + dtype: DtypeLike = np.int16, + dataset_key: str = "dataset_name", + allow_missing_keys: bool = False, + ) -> None: + """ + Remap the voxel labels in the input data dictionary based on the specified mapping. + + This list of local -> global label mappings will be applied to each input `data[keys]`. + if `data[dataset_key]` is not in `label_mappings`, label_mappings['default']` will be used. + if `label_mappings[data[dataset_key]]` is None, no relabeling will be performed. + + Args: + keys: keys of the corresponding items to be transformed. + label_mappings: a dictionary specifies how local dataset class indices are mapped to the + global class indices. The dictionary keys are dataset names and the values are lists of + list of (local label, global label) pairs. This list of local -> global label mappings + will be applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`, + label_mappings['default']` will be used. if `label_mappings[data[dataset_key]]` is None, + no relabeling will be performed. Please set `label_mappings={}` to completely skip this transform. + dtype: convert the output data to dtype, default to float32. + dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". + allow_missing_keys: don't raise exception if key is missing. 
+ + """ + super().__init__(keys, allow_missing_keys) + self.mappers = {} + self.dataset_key = dataset_key + for name, mapping in label_mappings.items(): + self.mappers[name] = MapLabelValue( + orig_labels=[int(pair[0]) for pair in mapping], + target_labels=[int(pair[1]) for pair in mapping], + dtype=dtype, + ) + + def __call__(self, data): + d = dict(data) + dataset_name = d.get(self.dataset_key, "default") + _m = look_up_option(dataset_name, self.mappers, default=None) + if _m is None: + return d + for key in self.key_iterator(d): + d[key] = _m(d[key]) + return d diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 6dd83c1f81c..142a3666694 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -18,6 +18,7 @@ import warnings import zipfile from collections.abc import Mapping, Sequence +from functools import partial from pathlib import Path from pydoc import locate from shutil import copyfile @@ -1254,6 +1255,7 @@ def verify_net_in_out( def _export( converter: Callable, + saver: Callable, parser: ConfigParser, net_id: str, filepath: str, @@ -1268,6 +1270,8 @@ def _export( Args: converter: a callable object that takes a torch.nn.module and kwargs as input and converts the module to another type. + saver: a callable object that accepts the converted model to save, a filepath to save to, meta values + (extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input. parser: a ConfigParser of the bundle to be converted. net_id: ID name of the network component in the parser, it must be `torch.nn.Module`. filepath: filepath to export, if filename has no extension, it becomes `.ts`. @@ -1307,14 +1311,9 @@ def _export( # add .json extension to all extra files which are always encoded as JSON extra_files = {k + ".json": v for k, v in extra_files.items()} - save_net_with_metadata( - jit_obj=net, - filename_prefix_or_stream=filepath, - include_config_vals=False, - append_timestamp=False, - meta_values=parser.get().pop("_meta_", None), - more_extra_files=extra_files, - ) + meta_values = parser.get().pop("_meta_", None) + saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files) + logger.info(f"exported to file: {filepath}.") @@ -1413,17 +1412,23 @@ def onnx_export( input_shape_ = _get_fake_input_shape(parser=parser) inputs_ = [torch.rand(input_shape_)] - net = parser.get_parsed_content(net_id_) - if has_ignite: - # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver - Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) - else: - ckpt = torch.load(ckpt_file_) - copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - onnx_model = convert_to_onnx(model=net, **converter_kwargs_) - onnx.save(onnx_model, filepath_) + + def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None: + onnx.save(onnx_obj, filename_prefix_or_stream) + + _export( + convert_to_onnx, + save_onnx, + parser, + net_id=net_id_, + filepath=filepath_, + ckpt_file=ckpt_file_, + config_file=config_file_, + key_in_ckpt=key_in_ckpt_, + **converter_kwargs_, + ) def ckpt_export( @@ -1544,8 +1549,12 @@ def ckpt_export( converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) # Use the given converter to convert a model and save with metadata, config content + + save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False) + _export( 
convert_to_torchscript, + save_ts, parser, net_id=net_id_, filepath=filepath_, @@ -1715,8 +1724,11 @@ def trt_export( } converter_kwargs_.update(trt_api_parameters) + save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False) + _export( convert_to_trt, + save_ts, parser, net_id=net_id_, filepath=filepath_, diff --git a/monai/data/utils.py b/monai/data/utils.py index 7a08300abb6..f35c5124d84 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -927,7 +927,7 @@ def compute_shape_offset( corners = in_affine_ @ corners all_dist = corners_out[:-1].copy() corners_out = corners_out[:-1] / corners_out[-1] - out_shape = np.round(corners_out.ptp(axis=1)) if scale_extent else np.round(corners_out.ptp(axis=1) + 1.0) + out_shape = np.round(np.ptp(corners_out, axis=1)) if scale_extent else np.round(np.ptp(corners_out, axis=1) + 1.0) offset = None for i in range(corners.shape[1]): min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1) diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 3488029a7ae..2ee8c9d3634 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -23,7 +23,7 @@ from monai.data.utils import iter_patch_position from monai.data.wsi_reader import BaseWSIReader, WSIReader from monai.transforms import ForegroundMask, Randomizable, apply_transform -from monai.utils import convert_to_dst_type, ensure_tuple_rep +from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys __all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"] @@ -123,9 +123,9 @@ def _get_label(self, sample: dict): def _get_location(self, sample: dict): if self.center_location: size = self._get_size(sample) - return [sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))] + return ensure_tuple(sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))) else: - return sample[WSIPatchKeys.LOCATION] + return ensure_tuple(sample[WSIPatchKeys.LOCATION]) def _get_level(self, sample: dict): if self.patch_level is None: diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index a080284e7ca..bd99765348c 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -300,6 +300,7 @@ def sliding_window_inference( # remove padding if image_size smaller than roi_size if any(pad_size): + kwargs.update({"pad_size": pad_size}) for ss, output_i in enumerate(output_image_list): zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)] final_slicing: list[slice] = [] diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 07a38d95726..44cde41e5de 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -666,6 +666,7 @@ def __init__( weight: torch.Tensor | None = None, lambda_dice: float = 1.0, lambda_ce: float = 1.0, + label_smoothing: float = 0.0, ) -> None: """ Args: @@ -704,6 +705,9 @@ def __init__( Defaults to 1.0. lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0. Defaults to 1.0. + label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed + by the given factor to reduce overfitting. + Defaults to 0.0. 
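+
+            Example (a brief usage sketch; ``label_smoothing`` only takes effect with PyTorch 1.10 or newer,
+            as handled in the implementation below):
+
+            .. code-block:: python
+
+                loss_fn = DiceCELoss(to_onehot_y=True, softmax=True, label_smoothing=0.1)
+                loss = loss_fn(logits, target)  # logits: BNH[WD], target: B1H[WD]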
""" super().__init__() @@ -728,7 +732,12 @@ def __init__( batch=batch, weight=dice_weight, ) - self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction) + if pytorch_after(1, 10): + self.cross_entropy = nn.CrossEntropyLoss( + weight=weight, reduction=reduction, label_smoothing=label_smoothing + ) + else: + self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction) self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction) if lambda_dice < 0.0: raise ValueError("lambda_dice should be no less than 0.0.") diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 3303e89bce1..27a712d308b 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -95,11 +95,11 @@ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: rmask: torch.Tensor if self.dim == 2: - oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() + oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 3, 1, 2).contiguous().float() rmask = self.svls_layer(oh_labels) if self.dim == 3: - oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float() + oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 4, 1, 2, 3).contiguous().float() rmask = self.svls_layer(oh_labels) return rmask diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index e56bd465927..047bfd0ab25 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -20,109 +20,108 @@ class GeneralizedDiceScore(CumulativeIterationMetric): - """Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in: + """ + Compute the Generalized Dice Score metric between tensors. + This metric is the complement of the Generalized Dice Loss defined in: Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning - loss function for highly unbalanced segmentations. DLMIA 2017. + loss function for highly unbalanced segmentations. DLMIA 2017. - The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first - or batch-first tensors, i.e., CHW[D] or BCHW[D]. + The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D]. Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. Args: - include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the + include_background: Whether to include the background class (assumed to be in channel 0) in the score computation. Defaults to True. - reduction (str, optional): define mode of reduction to the metrics. Available reduction modes: - {``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction. - weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform + reduction: Define mode of reduction to the metrics. Available reduction modes: + {``"none"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean"`, ``"sum"`}. Defaults to ``"mean"``. + If "none", will not do reduction. + weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. Raises: - ValueError: when the `weight_type` is not one of {``"none"``, ``"mean"``, ``"sum"``}. 
+ ValueError: When the `reduction` is not one of MetricReduction enum. """ def __init__( - self, - include_background: bool = True, - reduction: MetricReduction | str = MetricReduction.MEAN_BATCH, - weight_type: Weight | str = Weight.SQUARE, + self, include_background: bool = True, reduction: str = "mean", weight_type: Weight | str = Weight.SQUARE ) -> None: super().__init__() self.include_background = include_background - reduction_options = [ - "none", - "mean_batch", - "sum_batch", - MetricReduction.NONE, - MetricReduction.MEAN_BATCH, - MetricReduction.SUM_BATCH, - ] - self.reduction = reduction - if self.reduction not in reduction_options: - raise ValueError(f"reduction must be one of {reduction_options}") + self.reduction = look_up_option(reduction, MetricReduction) self.weight_type = look_up_option(weight_type, Weight) + self.sum_over_labels = self.reduction == MetricReduction.SUM or self.reduction == MetricReduction.MEAN - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] - """Computes the Generalized Dice Score and returns a tensor with its per image values. + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes the Generalized Dice Score and returns a tensor with its per image values. Args: - y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, + y_pred: Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. - y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. + y: Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. + + Returns: + torch.Tensor: Per batch and per class Generalized Dice Score. Raises: - ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape. + ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape. """ return compute_generalized_dice( - y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type + y_pred=y_pred, + y=y, + include_background=self.include_background, + weight_type=self.weight_type, + sum_over_labels=self.sum_over_labels, ) - def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor: + def aggregate(self) -> torch.Tensor: """ Execute reduction logic for the output of `compute_generalized_dice`. - Args: - reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics. - Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}. - Defaults to ``"mean"``. If "none", will not do reduction. + Returns: + torch.Tensor: Aggregated metric value. + + Raises: + ValueError: If the data to aggregate is not a PyTorch Tensor. 
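A short usage sketch of the updated metric API (illustrative only; the reduction mode is now fixed at construction and validated against `MetricReduction`):

import torch
import torch.nn.functional as F
from monai.metrics import GeneralizedDiceScore

metric = GeneralizedDiceScore(include_background=True, reduction="mean_batch", weight_type="square")
# one-hot NCHW tensors: N=4 images, C=3 classes
y_pred = F.one_hot(torch.randint(0, 3, (4, 32, 32)), num_classes=3).permute(0, 3, 1, 2).float()
y = F.one_hot(torch.randint(0, 3, (4, 32, 32)), num_classes=3).permute(0, 3, 1, 2).float()
metric(y_pred=y_pred, y=y)  # accumulate one iteration
score = metric.aggregate()  # reduced with the reduction given at construction
metric.reset()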
""" data = self.get_buffer() if not isinstance(data, torch.Tensor): raise ValueError("The data to aggregate must be a PyTorch Tensor.") - # Validate reduction argument if specified - if reduction is not None: - reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"] - if reduction not in reduction_options: - raise ValueError(f"reduction must be one of {reduction_options}") - # Do metric reduction and return - f, _ = do_metric_reduction(data, reduction or self.reduction) + f, _ = do_metric_reduction(data, self.reduction) return f def compute_generalized_dice( - y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE + y_pred: torch.Tensor, + y: torch.Tensor, + include_background: bool = True, + weight_type: Weight | str = Weight.SQUARE, + sum_over_labels: bool = False, ) -> torch.Tensor: - """Computes the Generalized Dice Score and returns a tensor with its per image values. + """ + Computes the Generalized Dice Score and returns a tensor with its per image values. Args: - y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format + y_pred: Binarized segmentation model output. It should be binarized, in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. - y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`. - include_background (bool, optional): whether to include score computation on the first channel of the + y: Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`. + include_background: Whether to include score computation on the first channel of the predicted output. Defaults to True. - weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to + weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. + sum_over_labels: Whether to sum the numerator and denominator across all labels before the final computation. Returns: - torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes]. + torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes]. Raises: - ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions, + ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions, or `y_pred` and `y` don't have the same shape. """ # Ensure tensors have at least 3 dimensions and have the same shape @@ -158,16 +157,21 @@ def compute_generalized_dice( b[infs] = 0 b[infs] = torch.max(b) - # Compute the weighted numerator and denominator, summing along the class axis - numer = 2.0 * (intersection * w).sum(dim=1) - denom = (denominator * w).sum(dim=1) + # Compute the weighted numerator and denominator, summing along the class axis when sum_over_labels is True + if sum_over_labels: + numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True) + denom = (denominator * w).sum(dim=1, keepdim=True) + y_pred_o = y_pred_o.sum(dim=-1, keepdim=True) + else: + numer = 2.0 * (intersection * w) + denom = denominator * w + y_pred_o = y_pred_o # Compute the score generalized_dice_score = numer / denom # Handle zero division. 
Where denom == 0 and the prediction volume is 0, score is 1. # Where denom == 0 but the prediction volume is not 0, score is 0 - y_pred_o = y_pred_o.sum(dim=-1) denom_zeros = denom == 0 generalized_dice_score[denom_zeros] = torch.where( (y_pred_o == 0)[denom_zeros], diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index daa5abdd561..bdecf63168f 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -59,13 +59,12 @@ def __init__( causal (bool, optional): whether to use causal attention. sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length. rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only - "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional - parameter size. + parameter size. attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use Pytorch's inbuilt - flash attention for a memory efficient attention mechanism (see - https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ super().__init__() @@ -109,7 +108,7 @@ def __init__( self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) - self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.dropout_rate = dropout_rate @@ -152,31 +151,20 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): # calculate query, key, values for all heads in batch and move head forward to be the batch dim b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) - q = self.to_q(x) + q = self.input_rearrange(self.to_q(x)) kv = context if context is not None else x _, kv_t, _ = kv.size() - k = self.to_k(kv) - v = self.to_v(kv) + k = self.input_rearrange(self.to_k(kv)) + v = self.input_rearrange(self.to_v(kv)) if self.attention_dtype is not None: q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) - q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) # - k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - if self.use_flash_attention: x = torch.nn.functional.scaled_dot_product_attention( - query=q.transpose(1, 2), - key=k.transpose(1, 2), - value=v.transpose(1, 2), - scale=self.scale, - dropout_p=self.dropout_rate, - is_causal=self.causal, - ).transpose( - 1, 2 - ) # Back to (b, nh, t, hs) + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined @@ -195,6 +183,7 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): 
att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 124c00acc67..ac96b077bd3 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Optional, Tuple +from typing import Tuple, Union import torch import torch.nn as nn @@ -40,9 +40,11 @@ def __init__( hidden_input_size: int | None = None, causal: bool = False, sequence_length: int | None = None, - rel_pos_embedding: Optional[str] = None, - input_size: Optional[Tuple] = None, - attention_dtype: Optional[torch.dtype] = None, + rel_pos_embedding: str | None = None, + input_size: Tuple | None = None, + attention_dtype: torch.dtype | None = None, + include_fc: bool = True, + use_combined_linear: bool = True, use_flash_attention: bool = False, ) -> None: """ @@ -61,9 +63,10 @@ def __init__( input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use Pytorch's inbuilt - flash attention for a memory efficient attention mechanism (see - https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
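To make the new projection options concrete, a small hypothetical construction (argument values are placeholders, not defaults from this change):

import torch
from monai.networks.blocks.selfattention import SABlock

# separate to_q/to_k/to_v projections instead of one fused qkv linear; final projection kept
block = SABlock(hidden_size=64, num_heads=4, use_combined_linear=False, include_fc=True)
x = torch.randn(2, 16, 64)  # (batch, tokens, hidden_size)
y = block(x)                # output keeps the same shape as x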
""" @@ -105,9 +108,22 @@ def __init__( self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) - self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) - self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) - self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.qkv: Union[nn.Linear, nn.Identity] + self.to_q: Union[nn.Linear, nn.Identity] + self.to_k: Union[nn.Linear, nn.Identity] + self.to_v: Union[nn.Linear, nn.Identity] + + if use_combined_linear: + self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) + self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript + self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) + else: + self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.qkv = nn.Identity() # add to enable torchscript + self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) + self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.dropout_rate = dropout_rate @@ -117,6 +133,8 @@ def __init__( self.attention_dtype = attention_dtype self.causal = causal self.sequence_length = sequence_length + self.include_fc = include_fc + self.use_combined_linear = use_combined_linear self.use_flash_attention = use_flash_attention if causal and sequence_length is not None: @@ -144,8 +162,13 @@ def forward(self, x): Return: torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C """ - output = self.input_rearrange(self.qkv(x)) - q, k, v = output[0], output[1], output[2] + if self.use_combined_linear: + output = self.input_rearrange(self.qkv(x)) + q, k, v = output[0], output[1], output[2] + else: + q = self.input_rearrange(self.to_q(x)) + k = self.input_rearrange(self.to_k(x)) + v = self.input_rearrange(self.to_v(x)) if self.attention_dtype is not None: q = q.to(self.attention_dtype) @@ -153,13 +176,8 @@ def forward(self, x): if self.use_flash_attention: x = F.scaled_dot_product_attention( - query=q.transpose(1, 2), - key=k.transpose(1, 2), - value=v.transpose(1, 2), - scale=self.scale, - dropout_p=self.dropout_rate, - is_causal=self.causal, - ).transpose(1, 2) + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale @@ -179,7 +197,9 @@ def forward(self, x): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) - x = self.out_proj(x) + if self.include_fc: + x = self.out_proj(x) x = self.drop_output(x) return x diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 1cfafb15851..665442b55ed 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -32,8 +32,13 @@ class SpatialAttentionBlock(nn.Module): spatial_dims: number of spatial dimensions, could be 1, 2, or 3. num_channels: number of input channels. Must be divisible by num_head_channels. num_head_channels: number of channels per head. + norm_num_groups: Number of groups for the group norm layer. + norm_eps: Epsilon for the normalization. 
attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -45,6 +50,8 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, attention_dtype: Optional[torch.dtype] = None, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -60,6 +67,8 @@ def __init__( num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 28d9c563ac5..05eb3b07aba 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -37,6 +37,8 @@ def __init__( sequence_length: int | None = None, with_cross_attention: bool = False, use_flash_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = True, ) -> None: """ Args: @@ -47,7 +49,9 @@ def __init__( qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism - (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. """ @@ -69,6 +73,8 @@ def __init__( save_attn=save_attn, causal=causal, sequence_length=sequence_length, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) self.norm2 = nn.LayerNorm(hidden_size) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index f62fe432fa2..34e69752eeb 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -77,7 +77,7 @@ resnet200, ) from .segresnet import SegResNet, SegResNetVAE -from .segresnet_ds import SegResNetDS +from .segresnet_ds import SegResNetDS, SegResNetDS2 from .senet import ( SENet, SEnet, @@ -119,6 +119,7 @@ from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder +from .vista3d import VISTA3D, vista3d132 from .vit import ViT from .vitautoenc import ViTAutoEnc from .vnet import VNet diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 35d80e0565f..af191e748bc 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -157,6 +157,10 @@ class Encoder(nn.Module): norm_eps: epsilon for the normalization. attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. 
+ include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -170,6 +174,9 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -220,6 +227,9 @@ def __init__( num_channels=input_channel, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -243,6 +253,9 @@ def __init__( num_channels=channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -291,6 +304,10 @@ class Decoder(nn.Module): attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -305,6 +322,9 @@ def __init__( attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, use_convtranspose: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -350,6 +370,9 @@ def __init__( num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -389,6 +412,9 @@ def __init__( num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -463,6 +489,10 @@ class AutoencoderKL(nn.Module): with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. use_checkpoint: if True, use activation checkpoint to save memory. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + include_fc: whether to include the final linear layer in the attention block. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
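For orientation, a hypothetical `AutoencoderKL` construction showing how the attention-related flags are passed through (all values are illustrative, not recommended settings):

from monai.networks.nets import AutoencoderKL

model = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    channels=(32, 64, 64),
    attention_levels=(False, False, True),
    latent_channels=8,
    include_fc=True,            # keep the attention output projection
    use_combined_linear=False,  # separate q/k/v projections
    use_flash_attention=False,
)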
""" def __init__( @@ -480,6 +510,9 @@ def __init__( with_decoder_nonlocal_attn: bool = True, use_checkpoint: bool = False, use_convtranspose: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -499,7 +532,7 @@ def __init__( "`num_channels`." ) - self.encoder = Encoder( + self.encoder: nn.Module = Encoder( spatial_dims=spatial_dims, in_channels=in_channels, channels=channels, @@ -509,8 +542,11 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) - self.decoder = Decoder( + self.decoder: nn.Module = Decoder( spatial_dims=spatial_dims, channels=channels, in_channels=latent_channels, @@ -521,6 +557,9 @@ def __init__( attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, use_convtranspose=use_convtranspose, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, @@ -665,27 +704,18 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] + attention_blocks = [k.replace(".attn.to_q.weight", "") for k in new_state_dict if "attn.to_q.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.to_q.weight"], - old_state_dict[f"{block}.to_k.weight"], - old_state_dict[f"{block}.to_v.weight"], - ], - dim=0, - ) - new_state_dict[f"{block}.attn.qkv.bias"] = torch.cat( - [ - old_state_dict[f"{block}.to_q.bias"], - old_state_dict[f"{block}.to_k.bias"], - old_state_dict[f"{block}.to_v.bias"], - ], - dim=0, - ) + new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight") + new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight") + new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight") + new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias") + new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias") + new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") + # old version did not have a projection so set these to the identity new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye( new_state_dict[f"{block}.attn.out_proj.weight"].shape[0] @@ -698,5 +728,8 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: for k in new_state_dict: if "postconv" in k: old_name = k.replace("postconv", "conv") - new_state_dict[k] = old_state_dict[old_name] - self.load_state_dict(new_state_dict) + new_state_dict[k] = old_state_dict.pop(old_name) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) + self.load_state_dict(new_state_dict, strict=True) diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py new file mode 100644 index 00000000000..308c3a6bcb2 --- /dev/null +++ 
b/monai/networks/nets/cell_sam_wrapper.py @@ -0,0 +1,92 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +from torch import nn +from torch.nn import functional as F + +from monai.utils import optional_import + +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") + +__all__ = ["CellSamWrapper"] + + +class CellSamWrapper(torch.nn.Module): + """ + CellSamWrapper is a thin wrapper around the SAM model https://github.com/facebookresearch/segment-anything + with an image-only decoder that can be used for segmentation tasks. + + + Args: + auto_resize_inputs: whether to resize inputs before passing to the network. + (usually they need to be resized, unless they are already at the expected size) + network_resize_roi: expected input size for the network. + (currently SAM expects 1024x1024) + checkpoint: checkpoint file to load the SAM weights from. + (this can be downloaded from SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) + return_features: whether to return features from SAM encoder + (without using decoder/upsampling to the original input size) + + """ + + def __init__( + self, + auto_resize_inputs=True, + network_resize_roi=(1024, 1024), + checkpoint="sam_vit_b_01ec64.pth", + return_features=False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + + self.network_resize_roi = network_resize_roi + self.auto_resize_inputs = auto_resize_inputs + self.return_features = return_features + + if not has_sam: + raise ValueError( + "SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git" + ) + + model = build_sam_vit_b(checkpoint=checkpoint) + + model.prompt_encoder = None + model.mask_decoder = None + + model.mask_decoder = nn.Sequential( + nn.BatchNorm2d(num_features=256), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False), + nn.BatchNorm2d(num_features=128), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + ) + + self.model = model + + def forward(self, x): + sh = x.shape[2:] + + if self.auto_resize_inputs: + x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear") + + x = self.model.image_encoder(x) + + if not self.return_features: + x = self.model.mask_decoder(x) + if self.auto_resize_inputs: + x = F.interpolate(x, size=sh, mode="bilinear") + + return x diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index ed3654733d6..8b8813597fe 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -143,6 +143,10 @@ class ControlNet(nn.Module): upcast_attention: if True, upcast attention operations to full precision. conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
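Looking back at the `CellSamWrapper` added above, a hypothetical usage sketch (it assumes the segment-anything package is installed and a local `sam_vit_b_01ec64.pth` checkpoint is available; shapes are illustrative):

import torch
from monai.networks.nets.cell_sam_wrapper import CellSamWrapper

net = CellSamWrapper(auto_resize_inputs=True, checkpoint="sam_vit_b_01ec64.pth")
x = torch.randn(1, 3, 256, 256)  # 3-channel input, resized to 1024x1024 internally
out = net(x)                     # decoder output resized back to 256x256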
conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -163,28 +167,29 @@ def __init__( upcast_attention: bool = False, conditioning_embedding_in_channels: int = 1, conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: raise ValueError( - "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) " "to be specified when with_conditioning=True." ) if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." - ) + raise ValueError("ControlNet expects with_conditioning=True when specifying the cross_attention_dim.") # All number of channels should be multiple of num_groups if any((out_channel % norm_num_groups) != 0 for out_channel in channels): raise ValueError( - f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got" + f"ControlNet expects all channels to be a multiple of norm_num_groups, but got" f" channels={channels} and norm_num_groups={norm_num_groups}" ) if len(channels) != len(attention_levels): raise ValueError( - f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got " + f"ControlNet expects channels to have the same length as attention_levels, but got " f"channels={channels} and attention_levels={attention_levels}" ) @@ -282,6 +287,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -326,6 +334,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) controlnet_block = Convolution( @@ -441,25 +452,16 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn1.to_q.weight"], - old_state_dict[f"{block}.attn1.to_k.weight"], - old_state_dict[f"{block}.attn1.to_v.weight"], - ], - dim=0, - ) - # projection - 
new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] - new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] - - new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] - new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] + new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") + new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index a885339d0d8..65d6053acc8 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -67,7 +67,9 @@ class DiffusionUNetTransformerBlock(nn.Module): cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism - (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. """ @@ -80,6 +82,8 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, ) -> None: super().__init__() self.attn1 = SABlock( @@ -89,6 +93,8 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) @@ -134,6 +140,11 @@ class SpatialTransformer(nn.Module): norm_eps: epsilon for the normalization. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
+ """ def __init__( @@ -148,6 +159,9 @@ def __init__( norm_eps: float = 1e-6, cross_attention_dim: int | None = None, upcast_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -175,6 +189,9 @@ def __init__( dropout=dropout, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) for _ in range(num_layers) ] @@ -529,6 +546,10 @@ class AttnDownBlock(nn.Module): resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -544,6 +565,9 @@ def __init__( resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -570,6 +594,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -636,7 +663,11 @@ class CrossAttnDownBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -656,6 +687,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -688,6 +722,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -745,6 +782,10 @@ class AttnMidBlock(nn.Module): norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. 
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -755,6 +796,9 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -772,6 +816,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.resnet_2 = DiffusionUNetResnetBlock( @@ -808,6 +855,10 @@ class CrossAttnMidBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -822,6 +873,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -844,6 +898,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.resnet_2 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, @@ -989,6 +1046,10 @@ class AttnUpBlock(nn.Module): add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1004,6 +1065,9 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1032,6 +1096,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -1116,7 +1183,11 @@ class CrossAttnUpBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. 
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1136,6 +1207,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1169,6 +1243,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -1250,6 +1327,9 @@ def get_down_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnDownBlock( @@ -1263,6 +1343,9 @@ def get_down_block( add_downsample=add_downsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnDownBlock( @@ -1280,6 +1363,9 @@ def get_down_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return DownBlock( @@ -1307,6 +1393,9 @@ def get_mid_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_conditioning: return CrossAttnMidBlock( @@ -1320,6 +1409,9 @@ def get_mid_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return AttnMidBlock( @@ -1329,6 +1421,9 @@ def get_mid_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) @@ -1350,6 +1445,9 @@ def get_up_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnUpBlock( @@ -1364,6 +1462,9 @@ def get_up_block( add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnUpBlock( @@ -1382,6 +1483,9 @@ def get_up_block( cross_attention_dim=cross_attention_dim, 
upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return UpBlock( @@ -1419,9 +1523,13 @@ class DiffusionModelUNet(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. + classes. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1442,6 +1550,9 @@ def __init__( num_class_embeds: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1536,6 +1647,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -1553,6 +1667,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) # up @@ -1587,6 +1704,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.up_blocks.append(up_block) @@ -1714,31 +1834,40 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + attention_blocks = [k.replace(".attn.to_k.weight", "") for k in new_state_dict if "attn.to_k.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn1.to_q.weight"], - old_state_dict[f"{block}.attn1.to_k.weight"], - old_state_dict[f"{block}.attn1.to_v.weight"], - ], - dim=0, - ) + new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight") + new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight") + new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight") + new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias") + new_state_dict[f"{block}.attn.to_k.bias"] = 
old_state_dict.pop(f"{block}.to_k.bias") + new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") # projection - new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] - new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] + new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight") + new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias") + + # fix the cross attention blocks + cross_attention_blocks = [ + k.replace(".out_proj.weight", "") + for k in new_state_dict + if "out_proj.weight" in k and "transformer_blocks" in k + ] + for block in cross_attention_blocks: + new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") + new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") - new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] - new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] # fix the upsample conv blocks which were renamed postconv for k in new_state_dict: if "postconv" in k: old_name = k.replace("postconv", "conv") - new_state_dict[k] = old_state_dict[old_name] + new_state_dict[k] = old_state_dict.pop(old_name) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) @@ -1782,6 +1911,9 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1866,6 +1998,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index 6430f5fdc9b..1ac5a79ee34 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -11,6 +11,7 @@ from __future__ import annotations +import copy from collections.abc import Callable from typing import Union @@ -23,7 +24,7 @@ from monai.networks.layers.utils import get_act_layer, get_norm_layer from monai.utils import UpsampleMode, has_option -__all__ = ["SegResNetDS"] +__all__ = ["SegResNetDS", "SegResNetDS2"] def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None): @@ -425,3 +426,128 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]: return self._forward(x) + + +class SegResNetDS2(SegResNetDS): + """ + SegResNetDS2 adds an additional decoder branch to SegResNetDS and is the image encoder of VISTA3D + `_. + + Args: + spatial_dims: spatial dimension of the input data. Defaults to 3. + init_filters: number of output channels for initial convolution layer. Defaults to 32. + in_channels: number of input channels for the network. Defaults to 1. + out_channels: number of output channels for the network. Defaults to 2. + act: activation type and arguments.
Defaults to ``RELU``. + norm: feature normalization type and arguments. Defaults to ``BATCH``. + blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``. + blocks_up: number of upsample blocks (optional). + dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level. + At dsdepth==1,only a single output is returned. + preprocess: optional callable function to apply before the model's forward pass + resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring + image spacing into an approximately isotropic space. + Otherwise, by default, the kernel size and downsampling is always isotropic. + + """ + + def __init__( + self, + spatial_dims: int = 3, + init_filters: int = 32, + in_channels: int = 1, + out_channels: int = 2, + act: tuple | str = "relu", + norm: tuple | str = "batch", + blocks_down: tuple = (1, 2, 2, 4), + blocks_up: tuple | None = None, + dsdepth: int = 1, + preprocess: nn.Module | Callable | None = None, + upsample_mode: UpsampleMode | str = "deconv", + resolution: tuple | None = None, + ): + super().__init__( + spatial_dims=spatial_dims, + init_filters=init_filters, + in_channels=in_channels, + out_channels=out_channels, + act=act, + norm=norm, + blocks_down=blocks_down, + blocks_up=blocks_up, + dsdepth=dsdepth, + preprocess=preprocess, + upsample_mode=upsample_mode, + resolution=resolution, + ) + + self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers]) + + def forward( # type: ignore + self, x: torch.Tensor, with_point: bool = True, with_label: bool = True + ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]: + """ + Args: + x: input tensor. + with_point: if true, return the point branch output. + with_label: if true, return the label branch output. + """ + if self.preprocess is not None: + x = self.preprocess(x) + + if not self.is_valid_shape(x): + raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}") + + x_down = self.encoder(x) + + x_down.reverse() + x = x_down.pop(0) + + if len(x_down) == 0: + x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)] + + outputs: list[torch.Tensor] = [] + outputs_auto: list[torch.Tensor] = [] + x_ = x.clone() + if with_point: + i = 0 + for level in self.up_layers: + x = level["upsample"](x) + x = x + x_down[i] + x = level["blocks"](x) + + if len(self.up_layers) - i <= self.dsdepth: + outputs.append(level["head"](x)) + i = i + 1 + + outputs.reverse() + x = x_ + if with_label: + i = 0 + for level in self.up_layers_auto: + x = level["upsample"](x) + x = x + x_down[i] + x = level["blocks"](x) + + if len(self.up_layers) - i <= self.dsdepth: + outputs_auto.append(level["head"](x)) + i = i + 1 + + outputs_auto.reverse() + + return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto + + def set_auto_grad(self, auto_freeze=False, point_freeze=False): + """ + Args: + auto_freeze: if true, freeze the image encoder and the auto-branch. + point_freeze: if true, freeze the image encoder and the point-branch. 
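A brief sketch of how the two decoder branches and the freezing helper are intended to be used (a minimal, illustrative example with placeholder values):

import torch
from monai.networks.nets import SegResNetDS2

net = SegResNetDS2(spatial_dims=3, init_filters=8, in_channels=1, out_channels=2, dsdepth=1)
x = torch.randn(1, 1, 32, 32, 32)  # spatial size divisible by the network's shape factor
point_out, auto_out = net(x, with_point=True, with_label=True)
net.set_auto_grad(auto_freeze=True, point_freeze=False)  # freeze the encoder and the auto branch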
+ """ + for param in self.encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + + for param in self.up_layers_auto.parameters(): + param.requires_grad = not auto_freeze + + for param in self.up_layers.parameters(): + param.requires_grad = not point_freeze diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py index d5794a9227f..cc8909194a1 100644 --- a/monai/networks/nets/spade_autoencoderkl.py +++ b/monai/networks/nets/spade_autoencoderkl.py @@ -137,6 +137,10 @@ class SPADEDecoder(nn.Module): label_nc: number of semantic channels for SPADE normalisation. with_nonlocal_attn: if True use non-local attention block. spade_intermediate_channels: number of intermediate channels for SPADE block layer. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -152,6 +156,9 @@ def __init__( label_nc: int, with_nonlocal_attn: bool = True, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -200,6 +207,9 @@ def __init__( num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -243,6 +253,9 @@ def __init__( num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -331,6 +344,9 @@ def __init__( with_encoder_nonlocal_attn: bool = True, with_decoder_nonlocal_attn: bool = True, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -360,6 +376,9 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.decoder = SPADEDecoder( spatial_dims=spatial_dims, @@ -373,6 +392,9 @@ def __init__( label_nc=label_nc, with_nonlocal_attn=with_decoder_nonlocal_attn, spade_intermediate_channels=spade_intermediate_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index 75d1687df34..a9609b1d396 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -325,6 +325,10 @@ class SPADEAttnUpBlock(nn.Module): resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. spade_intermediate_channels: number of intermediate channels for SPADE block layer + include_fc: whether to include the final linear layer. Default to True. 
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -342,6 +346,9 @@ def __init__( resblock_updown: bool = False, num_head_channels: int = 1, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -371,6 +378,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -457,6 +467,8 @@ class SPADECrossAttnUpBlock(nn.Module): cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. spade_intermediate_channels: number of intermediate channels for SPADE block layer. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism. + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -477,6 +489,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -510,6 +525,9 @@ def __init__( num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -592,6 +610,9 @@ def get_spade_up_block( cross_attention_dim: int | None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return SPADEAttnUpBlock( @@ -608,6 +629,9 @@ def get_spade_up_block( resblock_updown=resblock_updown, num_head_channels=num_head_channels, spade_intermediate_channels=spade_intermediate_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return SPADECrossAttnUpBlock( @@ -627,6 +651,7 @@ def get_spade_up_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, spade_intermediate_channels=spade_intermediate_channels, + use_flash_attention=use_flash_attention, ) else: return SPADEUpBlock( @@ -667,9 +692,11 @@ class SPADEDiffusionModelUNet(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. + classes. upcast_attention: if True, upcast attention operations to full precision. - spade_intermediate_channels: number of intermediate channels for SPADE block layer + spade_intermediate_channels: number of intermediate channels for SPADE block layer. 
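Editor's note: the `use_flash_attention` option threaded through these SPADE blocks routes attention to PyTorch's fused scaled_dot_product_attention. A standalone sketch (requires PyTorch >= 2.0, not part of the patch) showing that the fused call matches the explicit softmax formulation:

import math
import torch
import torch.nn.functional as F

b, heads, tokens, dim = 2, 4, 16, 32
q, k, v = (torch.randn(b, heads, tokens, dim) for _ in range(3))

# explicit softmax(QK^T / sqrt(d)) V
manual = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(dim), dim=-1) @ v
# fused kernel selected when use_flash_attention=True
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))  # True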
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -691,6 +718,9 @@ def __init__( num_class_embeds: int | None = None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -783,6 +813,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -799,6 +832,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) # up @@ -834,6 +870,7 @@ def __init__( upcast_attention=upcast_attention, label_nc=label_nc, spade_intermediate_channels=spade_intermediate_channels, + use_flash_attention=use_flash_attention, ) self.up_blocks.append(up_block) diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py index 1af725abdaa..3a278c112aa 100644 --- a/monai/networks/nets/transformer.py +++ b/monai/networks/nets/transformer.py @@ -51,6 +51,10 @@ class DecoderOnlyTransformer(nn.Module): attn_layers_heads: Number of attention heads. with_cross_attention: Whether to use cross attention for conditioning. embedding_dropout_rate: Dropout rate for the embedding. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" def __init__( @@ -62,6 +66,9 @@ def __init__( attn_layers_heads: int, with_cross_attention: bool = False, embedding_dropout_rate: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.num_tokens = num_tokens @@ -86,6 +93,9 @@ def __init__( causal=True, sequence_length=max_seq_len, with_cross_attention=with_cross_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) for _ in range(attn_layers_depth) ] @@ -133,25 +143,15 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] - - # fix the attention blocks - attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] - for block in attention_blocks: - new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn.to_q.weight"], - old_state_dict[f"{block}.attn.to_k.weight"], - old_state_dict[f"{block}.attn.to_v.weight"], - ], - dim=0, - ) + new_state_dict[k] = old_state_dict.pop(k) # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2 - for k in old_state_dict: + for k in list(old_state_dict.keys()): if "norm2" in k: - new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict[k] + new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict.pop(k) if "norm3" in k: - new_state_dict[k.replace("norm3", "norm2")] = old_state_dict[k] - + new_state_dict[k.replace("norm3", "norm2")] = old_state_dict.pop(k) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py new file mode 100644 index 00000000000..9148e365425 --- /dev/null +++ b/monai/networks/nets/vista3d.py @@ -0,0 +1,943 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from typing import Any, Callable, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +import monai +from monai.networks.blocks import MLPBlock, UnetrBasicBlock +from monai.networks.nets import SegResNetDS2 +from monai.transforms.utils import convert_points_to_disc +from monai.transforms.utils import keep_merge_components_with_points as lcc +from monai.transforms.utils import sample_points_from_label +from monai.utils import optional_import, unsqueeze_left, unsqueeze_right + +rearrange, _ = optional_import("einops", name="rearrange") + +__all__ = ["VISTA3D", "vista3d132"] + + +def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1): + """ + Exact VISTA3D network configuration used in https://arxiv.org/abs/2406.05285>`_. 
+ The model treats class index larger than 132 as zero-shot. + + Args: + encoder_embed_dim: hidden dimension for encoder. + in_channels: input channel number. + """ + segresnet = SegResNetDS2( + in_channels=in_channels, + blocks_down=(1, 2, 2, 4, 4), + norm="instance", + out_channels=encoder_embed_dim, + init_filters=encoder_embed_dim, + dsdepth=1, + ) + point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132) + class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True) + vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head) + return vista + + +class VISTA3D(nn.Module): + """ + VISTA3D based on: + `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography + `_. + + Args: + image_encoder: image encoder backbone for feature extraction. + class_head: class head used for class index based segmentation + point_head: point head used for interactive segmetnation + """ + + def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: nn.Module): + super().__init__() + self.image_encoder = image_encoder + self.class_head = class_head + self.point_head = point_head + self.image_embeddings = None + self.auto_freeze = False + self.point_freeze = False + self.NINF_VALUE = -9999 + self.PINF_VALUE = 9999 + + def update_slidingwindow_padding( + self, + pad_size: list | None, + labels: torch.Tensor | None, + prev_mask: torch.Tensor | None, + point_coords: torch.Tensor | None, + ): + """ + Image has been padded by sliding window inferer. + The related padding need to be performed outside of slidingwindow inferer. + + Args: + pad_size: padding size passed from sliding window inferer. + labels: image label ground truth. + prev_mask: previous segmentation mask. + point_coords: point click coordinates. + """ + if pad_size is None: + return labels, prev_mask, point_coords + if labels is not None: + labels = F.pad(labels, pad=pad_size, mode="constant", value=0) + if prev_mask is not None: + prev_mask = F.pad(prev_mask, pad=pad_size, mode="constant", value=0) + if point_coords is not None: + point_coords = point_coords + torch.tensor( + [pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device + ) + return labels, prev_mask, point_coords + + def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int: + """Get number of foreground classes based on class and point prompt.""" + if class_vector is None: + if point_coords is None: + raise ValueError("class_vector and point_coords cannot be both None.") + return point_coords.shape[0] + else: + return class_vector.shape[0] + + def convert_point_label( + self, + point_label: torch.Tensor, + label_set: Sequence[int] | None = None, + special_index: Sequence[int] = (23, 24, 25, 26, 27, 57, 128), + ): + """ + Convert point label based on its class prompt. For special classes defined in special index, + the positive/negative point label will be converted from 1/0 to 3/2. The purpose is to separate those + classes with ambiguous classes. + + Args: + point_label: the point label tensor, [B, N]. + label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, + this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot + evaluation, this label_set should be the original index. + special_index: the special class index that needs to be converted. 
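Editor's note: a minimal end-to-end sketch (not part of the patch) of the automatic, class-prompted path of the VISTA3D model defined here, assuming this branch of MONAI is installed; the volume size and class indices are arbitrary.

import torch
from monai.networks.nets.vista3d import vista3d132

model = vista3d132(encoder_embed_dim=48).eval()
image = torch.randn(1, 1, 32, 32, 32)    # [1, 1, H, W, D], divisible by the encoder's shape factor
class_vector = torch.tensor([[1], [2]])  # two supported global class indices, shape [B, 1]

with torch.no_grad():
    logits = model(image, class_vector=class_vector)
print(logits.shape)  # one logit map per requested class: torch.Size([2, 1, 32, 32, 32])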
+ """ + if label_set is None: + return point_label + if not point_label.shape[0] == len(label_set): + raise ValueError("point_label and label_set must have the same length.") + + for i in range(len(label_set)): + if label_set[i] in special_index: + for j in range(len(point_label[i])): + point_label[i, j] = point_label[i, j] + 2 if point_label[i, j] > -1 else point_label[i, j] + return point_label + + def sample_points_patch_val( + self, + labels: torch.Tensor, + patch_coords: Sequence[slice], + label_set: Sequence[int], + use_center: bool = True, + mapped_label_set: Sequence[int] | None = None, + max_ppoint: int = 1, + max_npoint: int = 0, + ): + """ + Sample points for patch during sliding window validation. Only used for point only validation. + + Args: + labels: shape [1, 1, H, W, D]. + patch_coords: a sequence of sliding window slice objects. + label_set: local index, must match values in labels. + use_center: sample points from the center. + mapped_label_set: global index, it is used to identify special classes and is the global index + for the sampled points. + max_ppoint/max_npoint: positive points and negative points to sample. + """ + point_coords, point_labels = sample_points_from_label( + labels[patch_coords], + label_set, + max_ppoint=max_ppoint, + max_npoint=max_npoint, + device=labels.device, + use_center=use_center, + ) + point_labels = self.convert_point_label(point_labels, mapped_label_set) + return (point_coords, point_labels, torch.tensor(label_set).to(point_coords.device).unsqueeze(-1)) + + def update_point_to_patch( + self, patch_coords: Sequence[slice], point_coords: torch.Tensor, point_labels: torch.Tensor + ): + """ + Update point_coords with respect to patch coords. + If point is outside of the patch, remove the coordinates and set label to -1. + + Args: + patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference. + This value is passed from sliding_window_inferer. + point_coords: point coordinates, [B, N, 3]. + point_labels: point labels, [B, N]. + """ + patch_ends = [patch_coords[-3].stop, patch_coords[-2].stop, patch_coords[-1].stop] + patch_starts = [patch_coords[-3].start, patch_coords[-2].start, patch_coords[-1].start] + # update point coords + patch_starts_tensor = unsqueeze_left(torch.tensor(patch_starts, device=point_coords.device), 2) + patch_ends_tensor = unsqueeze_left(torch.tensor(patch_ends, device=point_coords.device), 2) + # [1 N 1] + indices = torch.logical_and( + ((point_coords - patch_starts_tensor) > 0).all(2), ((patch_ends_tensor - point_coords) > 0).all(2) + ) + # check if it's within patch coords + point_coords = point_coords.clone() - patch_starts_tensor + point_labels = point_labels.clone() + if indices.any(): + point_labels[~indices] = -1 + point_coords[~indices] = 0 + # also remove padded points, mainly used for inference. + not_pad_indices = (point_labels != -1).any(0) + point_coords = point_coords[:, not_pad_indices] + point_labels = point_labels[:, not_pad_indices] + return point_coords, point_labels + return None, None + + def connected_components_combine( + self, + logits: torch.Tensor, + point_logits: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mapping_index: torch.Tensor, + thred: float = 0.5, + ): + """ + Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks + from a single image patch. 
+ Out of those B foreground masks, user may add points to a subset of B1 foreground masks for editing. + mapping_index represents the correspondence between B and B1. + For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. Meanwhile, the added/removed + region in point clicks must be updated by the lcc function. + Notice, if a positive point is within logits/prev_mask, the components containing the positive point will be added. + + Args: + logits: automatic branch results, [B, 1, H, W, D]. + point_logits: point branch results, [B1, 1, H, W, D]. + point_coords: point coordinates, [B1, N, 3]. + point_labels: point labels, [B1, N]. + mapping_index: [B]. + thred: the threshold to convert logits to binary. + """ + logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits + _logits = logits[mapping_index] + inside = [] + for i in range(_logits.shape[0]): + inside.append( + np.any( + [ + _logits[i, 0, p[0], p[1], p[2]].item() > 0 + for p in point_coords[i].cpu().numpy().round().astype(int) + ] + ) + ) + inside_tensor = torch.tensor(inside).to(logits.device) + nan_mask = torch.isnan(_logits) + # _logits are converted to binary [B1, 1, H, W, D] + _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid() + pos_region = point_logits.sigmoid() > thred + diff_pos = torch.logical_and(torch.logical_or(_logits <= thred, unsqueeze_right(inside_tensor, 5)), pos_region) + diff_neg = torch.logical_and((_logits > thred), ~pos_region) + cc = lcc(diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels) + # cc is the region that can be updated by point_logits. + cc = cc.to(logits.device) + # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask, + # only remove unconnected positive region. + uc_pos_region = torch.logical_and(pos_region, ~cc) + fill_mask = torch.logical_and(nan_mask, uc_pos_region) + if fill_mask.any(): + # fill in the mean negative value + point_logits[fill_mask] = -1 + # replace logits nan value and cc with point_logits + cc = torch.logical_or(nan_mask, cc).to(logits.dtype) + logits[mapping_index] *= 1 - cc + logits[mapping_index] += cc * point_logits + return logits + + def gaussian_combine( + self, + logits: torch.Tensor, + point_logits: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mapping_index: torch.Tensor, + radius: int | None = None, + ): + """ + Combine point results with auto results using gaussian. + + Args: + logits: automatic branch results, [B, 1, H, W, D]. + point_logits: point branch results, [B1, 1, H, W, D]. + point_coords: point coordinates, [B1, N, 3]. + point_labels: point labels, [B1, N]. + mapping_index: [B]. + radius: gaussian ball radius. + """ + if radius is None: + radius = min(point_logits.shape[-3:]) // 5 # empirical value 5 + weight = 1 - convert_points_to_disc(point_logits.shape[-3:], point_coords, point_labels, radius=radius).sum( + 1, keepdims=True + ) + weight[weight < 0] = 0 + logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits + logits[mapping_index] *= weight + logits[mapping_index] += (1 - weight) * point_logits + return logits + + def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False): + """ + Freeze auto-branch or point-branch. + + Args: + auto_freeze: whether to freeze the auto branch. + point_freeze: whether to freeze the point branch. 
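Editor's note: a standalone 1-D sketch (not part of the patch) of the blending rule behind `gaussian_combine`: a Gaussian bump centred on the click keeps the point-branch logits near the click and falls back to the automatic logits elsewhere.

import torch

size, click, radius = 32, 10, 4
auto_logits = torch.zeros(size)          # automatic branch (hypothetical values)
point_logits = torch.full((size,), 5.0)  # point branch strongly foreground everywhere

x = torch.arange(size, dtype=torch.float32)
weight = 1.0 - torch.exp(-((x - click) ** 2) / (2 * radius**2))  # ~0 at the click, ~1 far away

combined = weight * auto_logits + (1.0 - weight) * point_logits
print(combined[click].item())  # 5.0: the click location follows the point branch
print(combined[0].item())      # ~0.2: distant voxels stay close to the automatic result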
+ """ + if auto_freeze != self.auto_freeze: + if hasattr(self.image_encoder, "set_auto_grad"): + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) + else: + for param in self.image_encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + for param in self.class_head.parameters(): + param.requires_grad = not auto_freeze + self.auto_freeze = auto_freeze + + if point_freeze != self.point_freeze: + if hasattr(self.image_encoder, "set_auto_grad"): + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) + else: + for param in self.image_encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + for param in self.point_head.parameters(): + param.requires_grad = not point_freeze + self.point_freeze = point_freeze + + def forward( + self, + input_images: torch.Tensor, + point_coords: torch.Tensor | None = None, + point_labels: torch.Tensor | None = None, + class_vector: torch.Tensor | None = None, + prompt_class: torch.Tensor | None = None, + patch_coords: Sequence[slice] | None = None, + labels: torch.Tensor | None = None, + label_set: Sequence[int] | None = None, + prev_mask: torch.Tensor | None = None, + radius: int | None = None, + val_point_sampler: Callable | None = None, + transpose: bool = False, + **kwargs, + ): + """ + The forward function for VISTA3D. We only support single patch in training and inference. + One exception is allowing sliding window batch size > 1 for automatic segmentation only case. + B represents number of objects, N represents number of points for each objects. + + Args: + input_images: [1, 1, H, W, D] + point_coords: [B, N, 3] + point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class. + 2/3 means negative/postive ponits for special supported class like tumor. + class_vector: [B, 1], the global class index. + prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if + the points are for zero-shot or supported class. When class_vector and point_coords are both + provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] + will be considered novel class. + patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference. + This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase. + labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation + label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, + this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot + evaluation, this label_set should be the original index. + prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize]. + This is the transposed raw output from sliding_window_inferer before any postprocessing. + When user click points to perform auto-results correction, this can be the auto-results. + radius: single float value controling the gaussian blur when combining point and auto results. + The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes. + val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. + transpose: bool. If true, the output will be transposed to be [1, B, H, W, D]. 
Required to be true if calling from + sliding window inferer/point inferer. + """ + labels, prev_mask, point_coords = self.update_slidingwindow_padding( + kwargs.get("pad_size", None), labels, prev_mask, point_coords + ) + image_size = input_images.shape[-3:] + device = input_images.device + if point_coords is None and class_vector is None: + return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device) + + bs = self.get_foreground_class_count(class_vector, point_coords) + if patch_coords is not None: + # if during validation and perform enable based point-validation. + if labels is not None and label_set is not None: + # if labels is not None, sample from labels for each patch. + if val_point_sampler is None: + # TODO: think about how to refactor this part. + val_point_sampler = self.sample_points_patch_val + point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set) + if prompt_class[0].item() == 0: # type: ignore + point_labels[0] = -1 # type: ignore + labels, prev_mask = None, None + elif point_coords is not None: + # If not performing patch-based point only validation, use user provided click points for inference. + # the point clicks is in original image space, convert it to current patch-coordinate space. + point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore + + if point_coords is not None and point_labels is not None: + # remove points that used for padding purposes (point_label = -1) + mapping_index = ((point_labels != -1).sum(1) > 0).to(torch.bool) + if mapping_index.any(): + point_coords = point_coords[mapping_index] + point_labels = point_labels[mapping_index] + if prompt_class is not None: + prompt_class = prompt_class[mapping_index] + else: + if self.auto_freeze or (class_vector is None and patch_coords is None): + # if auto_freeze, point prompt must exist to allow loss backward + # in training, class_vector and point cannot both be None due to loss.backward() + mapping_index.fill_(True) + else: + point_coords, point_labels = None, None + + if point_coords is None and class_vector is None: + return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device) + + if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None: + out, out_auto = self.image_embeddings, None + else: + out, out_auto = self.image_encoder( + input_images, with_point=point_coords is not None, with_label=class_vector is not None + ) + # release memory + input_images = None # type: ignore + + # force releasing memories that set to None + torch.cuda.empty_cache() + if class_vector is not None: + logits, _ = self.class_head(out_auto, class_vector) + if point_coords is not None: + point_logits = self.point_head(out, point_coords, point_labels, class_vector=prompt_class) + if patch_coords is None: + logits = self.gaussian_combine( + logits, point_logits, point_coords, point_labels, mapping_index, radius # type: ignore + ) + else: + # during validation use largest component + logits = self.connected_components_combine( + logits, point_logits, point_coords, point_labels, mapping_index # type: ignore + ) + else: + logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device, dtype=out.dtype) + logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class) + if prev_mask is not None and patch_coords is not None: + logits = self.connected_components_combine( + prev_mask[patch_coords].transpose(1, 
0).to(logits.device), + logits[mapping_index], + point_coords, # type: ignore + point_labels, # type: ignore + mapping_index, + ) + if kwargs.get("keep_cache", False) and class_vector is None: + self.image_embeddings = out.detach() + if transpose: + logits = logits.transpose(1, 0) + return logits + + +class PointMappingSAM(nn.Module): + def __init__(self, feature_size: int, max_prompt: int = 32, n_classes: int = 512, last_supported: int = 132): + """Interactive point head used for VISTA3D. + Adapted from segment anything: + `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`. + + Args: + feature_size: feature channel from encoder. + max_prompt: max prompt number in each forward iteration. + n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings. + last_supported: number of classes the model support, this value should match the trained model weights. + """ + super().__init__() + transformer_dim = feature_size + self.max_prompt = max_prompt + self.feat_downsample = nn.Sequential( + nn.Conv3d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=2, padding=1), + nn.InstanceNorm3d(feature_size), + nn.GELU(), + nn.Conv3d(in_channels=feature_size, out_channels=transformer_dim, kernel_size=3, stride=1, padding=1), + nn.InstanceNorm3d(feature_size), + ) + + self.mask_downsample = nn.Conv3d(in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1) + + self.transformer = TwoWayTransformer(depth=2, embedding_dim=transformer_dim, mlp_dim=512, num_heads=4) + self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2) + self.point_embeddings = nn.ModuleList([nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)]) + self.not_a_point_embed = nn.Embedding(1, transformer_dim) + self.special_class_embed = nn.Embedding(1, transformer_dim) + self.mask_tokens = nn.Embedding(1, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose3d(transformer_dim, transformer_dim, kernel_size=3, stride=2, padding=1, output_padding=1), + nn.InstanceNorm3d(transformer_dim), + nn.GELU(), + nn.Conv3d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1), + ) + + self.output_hypernetworks_mlps = MLP(transformer_dim, transformer_dim, transformer_dim, 3) + # class embedding + self.n_classes = n_classes + self.last_supported = last_supported + self.class_embeddings = nn.Embedding(n_classes, feature_size) + self.zeroshot_embed = nn.Embedding(1, transformer_dim) + self.supported_embed = nn.Embedding(1, transformer_dim) + + def forward( + self, + out: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + class_vector: torch.Tensor | None = None, + ): + """Args: + out: feature from encoder, [1, C, H, W, C] + point_coords: point coordinates, [B, N, 3] + point_labels: point labels, [B, N] + class_vector: class prompts, [B] + """ + # downsample out + out_low = self.feat_downsample(out) + out_shape = tuple(out.shape[-3:]) + # release memory + out = None # type: ignore + torch.cuda.empty_cache() + # embed points + points = point_coords + 0.5 # Shift to center of pixel + point_embedding = self.pe_layer.forward_with_coords(points, out_shape) # type: ignore + point_embedding[point_labels == -1] = 0.0 + point_embedding[point_labels == -1] += self.not_a_point_embed.weight + point_embedding[point_labels == 0] += self.point_embeddings[0].weight + point_embedding[point_labels == 1] += self.point_embeddings[1].weight + 
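Editor's note: a standalone sketch (plain tensors stand in for the learned `nn.Embedding` weights, not part of the patch) of how point labels select prompt embeddings in `PointMappingSAM`: 0/1 are regular negative/positive clicks, 2/3 are the special-class variants, and -1 marks padded points.

import torch

dim = 8
emb_neg, emb_pos = torch.randn(1, dim), torch.randn(1, dim)  # stand-ins for point_embeddings[0]/[1]
emb_special = torch.randn(1, dim)                            # stand-in for special_class_embed
emb_not_a_point = torch.randn(1, dim)                        # stand-in for not_a_point_embed

point_labels = torch.tensor([[1, 0, 3, -1]])                 # [B=1, N=4]
point_embedding = torch.randn(1, 4, dim)                     # positional encoding of the clicks

point_embedding[point_labels == -1] = 0.0
point_embedding[point_labels == -1] += emb_not_a_point
point_embedding[point_labels == 0] += emb_neg
point_embedding[point_labels == 1] += emb_pos
point_embedding[point_labels == 2] += emb_neg + emb_special
point_embedding[point_labels == 3] += emb_pos + emb_special
print(point_embedding.shape)  # torch.Size([1, 4, 8])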
point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight + point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight + output_tokens = self.mask_tokens.weight + + output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1) + if class_vector is None: + tokens_all = torch.cat( + ( + output_tokens, + point_embedding, + self.supported_embed.weight.unsqueeze(0).expand(point_embedding.size(0), -1, -1), + ), + dim=1, + ) + # tokens_all = torch.cat((output_tokens, point_embedding), dim=1) + else: + class_embeddings = [] + for i in class_vector: + if i > self.last_supported: + class_embeddings.append(self.zeroshot_embed.weight) + else: + class_embeddings.append(self.supported_embed.weight) + tokens_all = torch.cat((output_tokens, point_embedding, torch.stack(class_embeddings)), dim=1) + # cross attention + masks = [] + max_prompt = self.max_prompt + for i in range(int(np.ceil(tokens_all.shape[0] / max_prompt))): + # remove variables in previous for loops to save peak memory for self.transformer + src, upscaled_embedding, hyper_in = None, None, None + torch.cuda.empty_cache() + idx = (i * max_prompt, min((i + 1) * max_prompt, tokens_all.shape[0])) + tokens = tokens_all[idx[0] : idx[1]] + src = torch.repeat_interleave(out_low, tokens.shape[0], dim=0) + pos_src = torch.repeat_interleave(self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0) + b, c, h, w, d = src.shape + hs, src = self.transformer(src, pos_src, tokens) + mask_tokens_out = hs[:, :1, :] + hyper_in = self.output_hypernetworks_mlps(mask_tokens_out) + src = src.transpose(1, 2).view(b, c, h, w, d) # type: ignore + upscaled_embedding = self.output_upscaling(src) + b, c, h, w, d = upscaled_embedding.shape + mask = hyper_in @ upscaled_embedding.view(b, c, h * w * d) + masks.append(mask.view(-1, 1, h, w, d)) + + return torch.vstack(masks) + + +class ClassMappingClassify(nn.Module): + """Class head that performs automatic segmentation based on class vector.""" + + def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = True): + """Args: + n_classes: maximum number of class embedding. + feature_size: class embedding size. + use_mlp: use mlp to further map class embedding. + """ + super().__init__() + self.use_mlp = use_mlp + if use_mlp: + self.mlp = nn.Sequential( + nn.Linear(feature_size, feature_size), + nn.InstanceNorm1d(1), + nn.GELU(), + nn.Linear(feature_size, feature_size), + ) + self.class_embeddings = nn.Embedding(n_classes, feature_size) + self.image_post_mapping = nn.Sequential( + UnetrBasicBlock( + spatial_dims=3, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name="instance", + res_block=True, + ), + UnetrBasicBlock( + spatial_dims=3, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name="instance", + res_block=True, + ), + ) + + def forward(self, src: torch.Tensor, class_vector: torch.Tensor): + b, c, h, w, d = src.shape + src = self.image_post_mapping(src) + class_embedding = self.class_embeddings(class_vector) + if self.use_mlp: + class_embedding = self.mlp(class_embedding) + # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension. 
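Editor's note: a standalone shape sketch (not part of the patch) of the masking step in `ClassMappingClassify`: each class embedding is matmul-ed with the flattened feature map, effectively a 1x1x1 convolution that yields one mask per requested class.

import torch

b_classes, c, h, w, d = 3, 16, 8, 8, 8
class_embedding = torch.randn(b_classes, 1, c)  # one embedding per requested class
features = torch.randn(1, c, h, w, d)           # post-mapping image features

masks = class_embedding @ features.view(1, c, h * w * d)  # broadcasts to [B, 1, H*W*D]
masks = masks.view(b_classes, 1, h, w, d)
print(masks.shape)  # torch.Size([3, 1, 8, 8, 8])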
+ masks = [] + for i in range(b): + mask = class_embedding @ src[[i]].view(1, c, h * w * d) + masks.append(mask.view(-1, 1, h, w, d)) + + return torch.cat(masks, 1), class_embedding + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: tuple | str = "relu", + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. + + Args: + depth: number of layers in the transformer. + embedding_dim: the channel dimension for the input embeddings. + num_heads: the number of heads for multihead attention. Must divide embedding_dim. + mlp_dim: the channel dimension internal to the MLP block. + activation: the activation to use in the MLP block. + attention_downsample_rate: the rate at which to downsample the image before projecting. + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, image_embedding: torch.Tensor, image_pe: torch.Tensor, point_embedding: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + image_embedding: image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe: the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding: the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding. + torch.Tensor: the processed image_embedding. + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer(queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: tuple | str = "relu", + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. + + Args: + embedding_dim: the channel dimension of the embeddings. + num_heads: the number of heads in the attention layers. + mlp_dim: the hidden dimension of the mlp block. + activation: the activation of the mlp block. + skip_first_layer_pe: skip the PE on the first layer. + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(hidden_size=embedding_dim, mlp_dim=mlp_dim, act=activation, dropout_mode="vista3d") + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. + + Args: + embedding_dim: the channel dimension of the embeddings. + num_heads: the number of heads in the attention layers. + downsample_rate: the rate at which to downsample the image before projecting. 
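Editor's note: a standalone sketch (not part of the patch) of the shape bookkeeping behind `downsample_rate` in this Attention layer: q/k/v are projected to internal_dim = embedding_dim // downsample_rate, attention runs per head at that reduced width, and the output projection maps back to embedding_dim.

import torch
from torch import nn

embedding_dim, downsample_rate, num_heads, tokens = 64, 2, 4, 10
internal_dim = embedding_dim // downsample_rate  # 32, must be divisible by num_heads

q_proj = nn.Linear(embedding_dim, internal_dim)
out_proj = nn.Linear(internal_dim, embedding_dim)

x = torch.randn(2, tokens, embedding_dim)
q = q_proj(x)                                                                    # [2, 10, 32]
q = q.reshape(2, tokens, num_heads, internal_dim // num_heads).transpose(1, 2)   # [2, 4, 10, 8]
# ... attention runs here at the reduced per-head width ...
q = q.transpose(1, 2).reshape(2, tokens, internal_dim)
print(out_proj(q).shape)  # torch.Size([2, 10, 64]), same width as the input embeddings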
+ """ + + def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + if not self.internal_dim % num_heads == 0: + raise ValueError("num_heads must divide embedding_dim.") + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + # B x N_heads x N_tokens x C_per_head + return x.transpose(1, 2) + + def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + # B x N_tokens x C + return x.reshape(b, n_tokens, n_heads * c_per_head) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py`. + + Args: + num_pos_feats: the number of positional encoding features. + scale: the scale of the positional encoding. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((3, num_pos_feats))) + + def _pe_encoding(self, coords: torch.torch.Tensor) -> torch.torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + # [bs=1,N=2,2] @ [2,128] + # [bs=1, N=2, 128] + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... 
x d_n x C shape + # [bs=1, N=2, 128+128=256] + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int, int]) -> torch.torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w, d = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w, d), device=device, dtype=torch.float32) + x_embed = grid.cumsum(dim=0) - 0.5 + y_embed = grid.cumsum(dim=1) - 0.5 + z_embed = grid.cumsum(dim=2) - 0.5 + x_embed = x_embed / h + y_embed = y_embed / w + z_embed = z_embed / d + pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1)) + # C x H x W + return pe.permute(3, 0, 1, 2) + + def forward_with_coords( + self, coords_input: torch.torch.Tensor, image_size: Tuple[int, int, int] + ) -> torch.torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[0] + coords[:, :, 1] = coords[:, :, 1] / image_size[1] + coords[:, :, 2] = coords[:, :, 2] / image_size[2] + # B x N x C + return self._pe_encoding(coords.to(torch.float)) + + +class MLP(nn.Module): + """ + Multi-layer perceptron. This class is only used for `PointMappingSAM`. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`. + + Args: + input_dim: the input dimension. + hidden_dim: the hidden dimension. + output_dim: the output dimension. + num_layers: the number of layers. + sigmoid_output: whether to apply a sigmoid activation to the output. + """ + + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.sigmoid_output = sigmoid_output + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 6a974342159..f301c2dd5cf 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -822,7 +822,7 @@ def _onnx_trt_compile( output_names = [] if not output_names else output_names # set up the TensorRT builder - torch_tensorrt.set_device(device) + torch.cuda.set_device(device) logger = trt.Logger(trt.Logger.WARNING) builder = trt.Builder(logger) network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) @@ -931,7 +931,7 @@ def convert_to_trt( warnings.warn(f"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.") device = device if device else 0 - target_device = torch.device(f"cuda:{device}") if device else torch.device("cuda:0") + target_device = torch.device(f"cuda:{device}") convert_precision = torch.float32 if precision == "fp32" else torch.half inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)] @@ -986,7 +986,7 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): ir_model, inputs=input_placeholder, enabled_precisions=convert_precision, - device=target_device, + device=torch_tensorrt.Device(f"cuda:{device}"), ir="torchscript", **kwargs, ) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ef1da2d855c..95484437688 
100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -688,6 +688,7 @@ weighted_patch_samples, zero_margins, ) +from .utils_morphological_ops import dilate, erode from .utils_pytorch_numpy_unification import ( allclose, any_np_pt, diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index e0ecc127f23..7c0e8f7123f 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -86,7 +86,7 @@ def switch_endianness(data, new="<"): if new not in ("<", ">"): raise NotImplementedError(f"Not implemented option new={new}.") if current_ != new: - data = data.byteswap().newbyteorder(new) + data = data.byteswap().view(data.dtype.newbyteorder(new)) elif isinstance(data, tuple): data = tuple(switch_endianness(x, new) for x in data) elif isinstance(data, list): diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index add4e7f5eaa..22726f06a5c 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -373,7 +373,7 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l if output_shape is None: corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape((len(im_shape), -1)) corners = transform[:-1, :-1] @ corners # type: ignore - output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) + output_shape = np.asarray(np.ptp(corners, axis=1) + 0.5, dtype=int) else: output_shape = np.asarray(output_shape, dtype=int) shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist()) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 3d09cea5450..15c2499a73e 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -203,8 +203,8 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState """ if seed is not None: - _seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed - _seed = _seed % MAX_SEED + _seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed) + _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64 self.R = np.random.RandomState(_seed) return self diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index d8461d927b4..7027c07d67b 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -22,6 +22,7 @@ import numpy as np import torch +from torch import Tensor import monai from monai.config import DtypeLike, IndexSelection @@ -30,6 +31,7 @@ from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform, Transform, apply_transform +from monai.transforms.utils_morphological_ops import erode from monai.transforms.utils_pytorch_numpy_unification import ( any_np_pt, ascontiguousarray, @@ -65,6 +67,8 @@ min_version, optional_import, pytorch_after, + unsqueeze_left, + unsqueeze_right, ) from monai.utils.enums import TransformBackends from monai.utils.type_conversion import ( @@ -103,6 +107,9 @@ "generate_spatial_bounding_box", "get_extreme_points", "get_largest_connected_component_mask", + "keep_merge_components_with_points", + "keep_components_with_positive_points", + "convert_points_to_disc", "remove_small_objects", "img_bounds", "in_bounds", @@ -1172,6 +1179,227 @@ def get_largest_connected_component_mask( return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] +def keep_merge_components_with_points( + img_pos: NdarrayTensor, 
+ img_neg: NdarrayTensor, + point_coords: NdarrayTensor, + point_labels: NdarrayTensor, + pos_val: Sequence[int] = (1, 3), + neg_val: Sequence[int] = (0, 2), + margins: int = 3, +) -> NdarrayTensor: + """ + Keep connected regions of img_pos and img_neg that include the positive points and + negative points separately. The function is used for merging automatic results with interactive + results in VISTA3D. + + Args: + img_pos: bool type tensor, shape [B, 1, H, W, D], where B means the foreground masks from a single 3D image. + img_neg: same format as img_pos but corresponds to negative points. + pos_val: positive point label values. + neg_val: negative point label values. + point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points. + point_labels: the label of each point, shape [B, N]. + margins: include points outside of the region but within the margin. + """ + + cucim_skimage, has_cucim = optional_import("cucim.skimage") + + use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu") + if use_cp: + img_pos_ = convert_to_cupy(img_pos.short()) # type: ignore + img_neg_ = convert_to_cupy(img_neg.short()) # type: ignore + label = cucim_skimage.measure.label + lib = cp + else: + if not has_measure: + raise RuntimeError("skimage.measure required.") + img_pos_, *_ = convert_data_type(img_pos, np.ndarray) + img_neg_, *_ = convert_data_type(img_neg, np.ndarray) + # for skimage.measure.label, the input must be bool type + if img_pos_.dtype != bool or img_neg_.dtype != bool: + raise ValueError("img_pos and img_neg must be bool type.") + label = measure.label + lib = np + + features_pos, _ = label(img_pos_, connectivity=3, return_num=True) + features_neg, _ = label(img_neg_, connectivity=3, return_num=True) + + outs = np.zeros_like(img_pos_) + for bs in range(point_coords.shape[0]): + for i, p in enumerate(point_coords[bs]): + if point_labels[bs, i] in pos_val: + features = features_pos + elif point_labels[bs, i] in neg_val: + features = features_neg + else: + # if -1 padding point, skip + continue + for margin in range(margins): + if isinstance(p, np.ndarray): + x, y, z = np.round(p).astype(int).tolist() + else: + x, y, z = p.float().round().int().tolist() + l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3]) + t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2]) + f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1]) + if (features[bs, 0, l:r, t:d, f:b] > 0).any(): + index = features[bs, 0, l:r, t:d, f:b].max() + outs[[bs]] += lib.isin(features[[bs]], index) + break + outs[outs > 1] = 1 + return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] + + +def keep_components_with_positive_points( + img: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor +) -> torch.Tensor: + """ + Keep connected regions that include the positive points. Used for point-only inference postprocessing to remove + regions without positive points. + Args: + img: [1, B, H, W, D]. Output prediction from VISTA3D. Value is before sigmoid and contain NaN value. + point_coords: [B, N, 3]. Point click coordinates + point_labels: [B, N]. Point click labels. + """ + if not has_measure: + raise RuntimeError("skimage.measure required.") + outs = torch.zeros_like(img) + for c in range(len(point_coords)): + if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()): + # skip if no positive points. 
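Editor's note: a standalone sketch (not part of the patch; scipy.ndimage stands in for the cucim/skimage labelling used above) of the idea behind `keep_components_with_positive_points`: label the connected components of a binary prediction and keep only the components that contain a positive click.

import numpy as np
from scipy import ndimage

pred = np.zeros((16, 16, 16), dtype=bool)
pred[2:5, 2:5, 2:5] = True        # component touching the click
pred[10:13, 10:13, 10:13] = True  # spurious component far from any click

positive_clicks = [(3, 3, 3)]

components, _ = ndimage.label(pred)
keep_ids = {components[p] for p in positive_clicks} - {0}  # drop clicks that hit background
filtered = np.isin(components, list(keep_ids))
print(filtered.sum(), pred.sum())  # 27 54: only the clicked component survives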
+ continue + coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist() + not_nan_mask = ~torch.isnan(img[0, c]) + img_ = torch.nan_to_num(img[0, c] > 0, 0) + img_, *_ = convert_data_type(img_, np.ndarray) # type: ignore + label = measure.label + features = label(img_, connectivity=3) + pos_mask = torch.from_numpy(img_).to(img.device) > 0 + # if num features less than max desired, nothing to do. + features = torch.from_numpy(features).to(img.device) + # generate a map with all pos points + idx = [] + for p in coords: + idx.append(features[round(p[0]), round(p[1]), round(p[2])].item()) + idx = list(set(idx)) + for i in idx: + if i == 0: + continue + outs[0, c] += features == i + outs = outs > 0 + # find negative mean value + fill_in = img[0, c][torch.logical_and(~outs[0, c], not_nan_mask)].mean() + img[0, c][torch.logical_and(pos_mask, ~outs[0, c])] = fill_in + return img + + +def convert_points_to_disc( + image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False +): + """ + Convert a 3D point coordinates into image mask. The returned mask has the same spatial + size as `image_size` while the batch dimension is the same as 'point' batch dimension. + The point is converted to a mask ball with radius defined by `radius`. The output + contains two channels each for negative (first channel) and positive points. + + Args: + image_size: The output size of the converted mask. It should be a 3D tuple. + point: [B, N, 3], 3D point coordinates. + point_label: [B, N], 0 or 2 means negative points, 1 or 3 means postive points. + radius: disc ball radius size. + disc: If true, use regular disc, other use gaussian. + """ + masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device) + _array = [ + torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3) + ] + coord_rows, coord_cols, coord_z = torch.meshgrid(_array[0], _array[1], _array[2]) + # [1, 3, h, w, d] -> [b, 2, 3, h, w, d] + coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) + coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) + for b, n in np.ndindex(*point.shape[:2]): + point_bn = unsqueeze_right(point[b, n], 4) + if point_label[b, n] > -1: + channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 + pow_diff = torch.pow(coords[b, channel] - point_bn, 2) + if disc: + masks[b, channel] += pow_diff.sum(0) < radius**2 + else: + masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2)) + return masks + + +def sample_points_from_label( + labels: Tensor, + label_set: Sequence[int], + max_ppoint: int = 1, + max_npoint: int = 0, + device: torch.device | str | None = "cpu", + use_center: bool = False, +): + """Sample points from labels. + + Args: + labels: [1, 1, H, W, D] + label_set: local index, must match values in labels. + max_ppoint: maximum positive point samples. + max_npoint: maximum negative point samples. + device: returned tensor device. + use_center: whether to sample points from center. + + Returns: + point: point coordinates of [B, N, 3]. B equals to the length of label_set. + point_label: [B, N], always 0 for negative, 1 for positive. 
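Editor's note: a standalone sketch (not part of the patch) of the `use_center` strategy in `sample_points_from_label`: among the foreground voxels of one class, take the voxel closest to the class centroid as the positive click.

import torch

labels = torch.zeros(16, 16, 16, dtype=torch.long)
labels[4:9, 4:9, 4:9] = 7  # one foreground class with id 7

fg = torch.nonzero(labels == 7).float()  # [num_voxels, 3]
centroid = fg.mean(0)
center_point = fg[((fg - centroid) ** 2).sum(-1).argmin()]
print(center_point)  # tensor([6., 6., 6.]), the most central foreground voxel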
+ """ + if not labels.shape[0] == 1: + raise ValueError("labels must have batch size 1.") + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + labels = labels[0, 0] + unique_labels = labels.unique().cpu().numpy().tolist() + _point = [] + _point_label = [] + for id in label_set: + if id in unique_labels: + plabels = labels == int(id) + nlabels = ~plabels + _plabels = get_largest_connected_component_mask(erode(plabels.unsqueeze(0).unsqueeze(0))[0, 0]) + plabelpoints = torch.nonzero(_plabels).to(device) + if len(plabelpoints) == 0: + plabelpoints = torch.nonzero(plabels).to(device) + nlabelpoints = torch.nonzero(nlabels).to(device) + num_p = min(len(plabelpoints), max_ppoint) + num_n = min(len(nlabelpoints), max_npoint) + pad = max_ppoint + max_npoint - num_p - num_n + if use_center: + pmean = plabelpoints.float().mean(0) + pdis = ((plabelpoints - pmean) ** 2).sum(-1) + _, sorted_indices_tensor = torch.sort(pdis) + sorted_indices = sorted_indices_tensor.cpu().tolist() + else: + sorted_indices = list(range(len(plabelpoints))) + random.shuffle(sorted_indices) + _point.append( + torch.stack( + [plabelpoints[sorted_indices[i]] for i in range(num_p)] + + random.choices(nlabelpoints, k=num_n) + + [torch.tensor([0, 0, 0], device=device)] * pad + ) + ) + _point_label.append(torch.tensor([1] * num_p + [0] * num_n + [-1] * pad).to(device)) + else: + # pad the background labels + _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device)) + _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1) + point = torch.stack(_point) + point_label = torch.stack(_point_label) + + return point, point_label + + def remove_small_objects( img: NdarrayTensor, min_size: int = 64, diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/transforms/utils_morphological_ops.py similarity index 99% rename from monai/apps/generation/maisi/utils/morphological_ops.py rename to monai/transforms/utils_morphological_ops.py index 14786d60a28..b3134c18656 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/transforms/utils_morphological_ops.py @@ -20,6 +20,8 @@ from monai.config import NdarrayOrTensor from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep +__all__ = ["erode", "dilate"] + def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor: """ diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 020d99af165..98b75cff76a 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -480,7 +480,7 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.max(x, int(dim), **kwargs) # type: ignore - return ret + return ret[0] if isinstance(ret, tuple) else ret def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor: @@ -546,7 +546,7 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.min(x, int(dim), **kwargs) # type: ignore - return ret + return ret[0] if isinstance(ret, tuple) else ret def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor: diff --git a/monai/utils/module.py b/monai/utils/module.py index 4d28f8d9869..78087aef843 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -13,7 +13,6 @@ import enum import functools 
-import importlib.util import os import pdb import re @@ -209,10 +208,11 @@ def load_submodules( ): if (is_pkg or load_all) and name not in sys.modules and match(exclude_pattern, name) is None: try: + mod = import_module(name) mod_spec = importer.find_spec(name) # type: ignore if mod_spec and mod_spec.loader: - mod = importlib.util.module_from_spec(mod_spec) - mod_spec.loader.exec_module(mod) + loader = mod_spec.loader + loader.exec_module(mod) submodules.append(mod) except OptionalImportError: pass # could not import the optional deps., they are ignored @@ -564,7 +564,7 @@ def version_leq(lhs: str, rhs: str) -> bool: """ lhs, rhs = str(lhs), str(rhs) - pkging, has_ver = optional_import("pkg_resources", name="packaging") + pkging, has_ver = optional_import("packaging.Version") if has_ver: try: return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs)) @@ -591,7 +591,8 @@ def version_geq(lhs: str, rhs: str) -> bool: """ lhs, rhs = str(lhs), str(rhs) - pkging, has_ver = optional_import("pkg_resources", name="packaging") + pkging, has_ver = optional_import("packaging.Version") + if has_ver: try: return cast(bool, pkging.version.Version(lhs) >= pkging.version.Version(rhs)) @@ -629,7 +630,7 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st if current_ver_string is None: _env_var = os.environ.get("PYTORCH_VER", "") current_ver_string = _env_var if _env_var else torch.__version__ - ver, has_ver = optional_import("pkg_resources", name="parse_version") + ver, has_ver = optional_import("packaging.version", name="parse") if has_ver: return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") # type: ignore parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3) diff --git a/requirements-dev.txt b/requirements-dev.txt index 72ba2100933..9aad0804e63 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -58,4 +58,4 @@ lpips==0.1.4 nvidia-ml-py huggingface_hub pyamg>=5.0.0 -git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd +git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 diff --git a/requirements-min.txt b/requirements-min.txt index a091ef05681..21cf9d5e5c5 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -4,3 +4,4 @@ setuptools>=50.3.0,<66.0.0,!=60.6.0 ; python_version < "3.12" setuptools>=70.2.0; python_version >= "3.12" coverage>=5.5 parameterized +packaging diff --git a/requirements.txt b/requirements.txt index aae455f58cb..e184322c135 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ torch>=1.9 -numpy>=1.20,<=1.26.0 +numpy>=1.24,<2.0 diff --git a/setup.cfg b/setup.cfg index 202e7b0e243..1ce4a3f34c8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,7 @@ setup_requires = ninja install_requires = torch>=1.9 - numpy>=1.20 + numpy>=1.24,<2.0 [options.extras_require] all = @@ -137,6 +137,8 @@ pyyaml = pyyaml fire = fire +packaging = + packaging jsonschema = jsonschema pynrrd = @@ -160,11 +162,13 @@ pynvml = nvidia-ml-py # # workaround https://github.com/Project-MONAI/MONAI/issues/5882 # MetricsReloaded = -# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded + # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded huggingface_hub = huggingface_hub pyamg = pyamg>=5.0.0 +# segment-anything = +# segment-anything @ 
git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything [flake8] select = B,C,E,F,N,P,T4,W,B9 diff --git a/setup.py b/setup.py index b90d9d09761..576743c1f72 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ import sys import warnings -import pkg_resources +from packaging import version from setuptools import find_packages, setup import versioneer @@ -40,7 +40,7 @@ BUILD_CUDA = FORCE_CUDA or (torch.cuda.is_available() and (CUDA_HOME is not None)) - _pt_version = pkg_resources.parse_version(torch.__version__).release + _pt_version = version.parse(torch.__version__).release if _pt_version is None or len(_pt_version) < 3: raise AssertionError("unknown torch version") TORCH_VERSION = int(_pt_version[0]) * 10000 + int(_pt_version[1]) * 100 + int(_pt_version[2]) diff --git a/tests/min_tests.py b/tests/min_tests.py index 3a143df84b0..f80d06f5d30 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -209,6 +209,8 @@ def run_testsuit(): "test_zarr_avg_merger", "test_perceptual_loss", "test_ultrasound_confidence_map_transform", + "test_vista3d_utils", + "test_vista3d_transforms", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index 01dc0448704..107114861cd 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -127,7 +127,7 @@ def test_loading_mmar(self, item): in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16), - pos_embed="conv", + proj_type="conv", hidden_size=768, mlp_dim=3072, ) diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index 392a3d7db20..0e9f427fb64 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -16,16 +16,13 @@ import torch from parameterized import parameterized +from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi from monai.networks import eval_mode from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion tqdm, has_tqdm = optional_import("tqdm", name="tqdm") _, has_einops = optional_import("einops") -_, has_generative = optional_import("generative") - -if has_generative: - from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -79,7 +76,6 @@ CASES = CASES_NO_ATTENTION -@unittest.skipUnless(has_generative, "monai-generative required") class TestAutoencoderKlMaisi(unittest.TestCase): @parameterized.expand(CASES) diff --git a/tests/test_cell_sam_wrapper.py b/tests/test_cell_sam_wrapper.py new file mode 100644 index 00000000000..2f1ee2b9011 --- /dev/null +++ b/tests/test_cell_sam_wrapper.py @@ -0,0 +1,58 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.cell_sam_wrapper import CellSamWrapper +from monai.utils import optional_import + +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") + +device = "cuda" if torch.cuda.is_available() else "cpu" +TEST_CASE_CELLSEGWRAPPER = [] +for dims in [128, 256, 512, 1024]: + test_case = [ + {"auto_resize_inputs": True, "network_resize_roi": [1024, 1024], "checkpoint": None}, + (1, 3, *([dims] * 2)), + (1, 3, *([dims] * 2)), + ] + TEST_CASE_CELLSEGWRAPPER.append(test_case) + + +@unittest.skipUnless(has_sam, "Requires SAM installation") +class TestResNetDS(unittest.TestCase): + + @parameterized.expand(TEST_CASE_CELLSEGWRAPPER) + def test_shape(self, input_param, input_shape, expected_shape): + net = CellSamWrapper(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + def test_ill_arg0(self): + with self.assertRaises(RuntimeError): + net = CellSamWrapper(auto_resize_inputs=False, checkpoint=None).to(device) + net(torch.randn([1, 3, 256, 256]).to(device)) + + def test_ill_arg1(self): + with self.assertRaises(RuntimeError): + net = CellSamWrapper(network_resize_roi=[256, 256], checkpoint=None).to(device) + net(torch.randn([1, 3, 1024, 1024]).to(device)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py index e04444e988c..fd7745245c5 100644 --- a/tests/test_compute_generalized_dice.py +++ b/tests/test_compute_generalized_dice.py @@ -22,7 +22,7 @@ _device = "cuda:0" if torch.cuda.is_available() else "cpu" # keep background -TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1) +TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1) { "y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device), "y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device), @@ -32,7 +32,7 @@ ] # remove background -TEST_CASE_2 = [ # y (2, 1, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background) +TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 2) (no background) { "y_pred": torch.tensor( [ @@ -48,11 +48,11 @@ ), "include_background": False, }, - [0.1667, 0.6667], + [0.416667], ] # should return 0 for both cases -TEST_CASE_3 = [ +TEST_CASE_3 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 3) { "y_pred": torch.tensor( [ @@ -68,7 +68,7 @@ ), "include_background": True, }, - [0.0, 0.0], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], ] TEST_CASE_4 = [ @@ -87,11 +87,11 @@ ] ), }, - [0.5455], + [0.678571, 0.2, 0.333333], ] TEST_CASE_5 = [ - {"include_background": True, "reduction": "sum_batch"}, + {"include_background": True, "reduction": "sum"}, { "y_pred": torch.tensor( [ @@ -106,16 +106,28 @@ ] ), }, - 1.0455, + [1.045455], ] -TEST_CASE_6 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [1.0000, 1.0000]] +TEST_CASE_6 = [ + {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, + [[1.0000, 1.0000], [1.0000, 1.0000]], +] -TEST_CASE_7 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [0.0000, 0.0000]] +TEST_CASE_7 = [ + {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, + [[0.0000, 0.0000], [0.0000, 0.0000]], +] -TEST_CASE_8 = [{"y": 
torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [0.0000, 0.0000]] +TEST_CASE_8 = [ + {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, + [[0.0000, 0.0000], [0.0000, 0.0000]], +] -TEST_CASE_9 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [1.0000, 1.0000]] +TEST_CASE_9 = [ + {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, + [[1.0000, 1.0000], [1.0000, 1.0000]], +] class TestComputeGeneralizedDiceScore(unittest.TestCase): @@ -126,7 +138,7 @@ def test_device(self, input_data, _expected_value): np.testing.assert_equal(result.device, input_data["y_pred"].device) # Functional part tests - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) + @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) def test_value(self, input_data, expected_value): result = compute_generalized_dice(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) @@ -146,7 +158,7 @@ def test_value_class(self, input_data, expected_value): vals["y"] = input_data.pop("y") generalized_dice_score = GeneralizedDiceScore(**input_data) generalized_dice_score(**vals) - result = generalized_dice_score.aggregate(reduction="none") + result = generalized_dice_score.aggregate() np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) # Aggregation tests diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py index 7b0e69f2c80..bfdf25ec6ec 100644 --- a/tests/test_controlnet_maisi.py +++ b/tests/test_controlnet_maisi.py @@ -17,14 +17,12 @@ import torch from parameterized import parameterized +from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi from monai.networks import eval_mode from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion -_, has_generative = optional_import("generative") - -if has_generative: - from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi +_, has_einops = optional_import("einops") TEST_CASES = [ [ @@ -103,8 +101,8 @@ TEST_CASES_ERROR = [ [ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": True, "cross_attention_dim": None}, - "ControlNet expects dimension of the cross-attention conditioning " - "(cross_attention_dim) when using with_conditioning.", + "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "to be specified when with_conditioning=True.", ], [ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": False, "cross_attention_dim": 2}, @@ -112,7 +110,8 @@ ], [ {"spatial_dims": 2, "in_channels": 1, "num_channels": (8, 16), "norm_num_groups": 16}, - "ControlNet expects all num_channels being multiple of norm_num_groups", + f"ControlNet expects all channels to be a multiple of norm_num_groups, but got" + f" channels={(8, 16)} and norm_num_groups={16}", ], [ { @@ -122,16 +121,17 @@ "attention_levels": (True,), "norm_num_groups": 8, }, - "ControlNet expects num_channels being same size of attention_levels", + f"ControlNet expects channels to have the same length as attention_levels, but got " + f"channels={(8, 16)} and attention_levels={(True,)}", ], ] @SkipIfBeforePyTorchVersion((2, 0)) -@skipUnless(has_generative, "monai-generative required") class TestControlNet(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param, 
expected_num_down_blocks_residuals, expected_shape): net = ControlNetMaisi(**input_param) with eval_mode(net): @@ -145,6 +145,7 @@ def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_ self.assertEqual(result[1].shape, expected_shape) @parameterized.expand(TEST_CASES_CONDITIONAL) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): net = ControlNetMaisi(**input_param) with eval_mode(net): diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 44458147d6e..e034e422903 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -22,7 +22,7 @@ from monai.networks.blocks.crossattention import CrossAttentionBlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose einops, has_einops = optional_import("einops") @@ -166,6 +166,21 @@ def test_access_attn_matrix(self): matrix_acess_blk(torch.randn(input_shape)) assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + @parameterized.expand([[True], [False]]) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_flash_attention(self, causal): + input_param = {"hidden_size": 128, "num_heads": 1, "causal": causal, "sequence_length": 16 if causal else None} + device = "cuda:0" if torch.cuda.is_available() else "cpu" + block_w_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device) + block_wo_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device) + block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict()) + test_data = torch.randn(1, 16, 128).to(device) + + out_1 = block_w_flash_attention(test_data) + out_2 = block_wo_flash_attention(test_data) + assert_allclose(out_1, out_2, atol=1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_diffusion_model_unet_maisi.py b/tests/test_diffusion_model_unet_maisi.py index 059a4a4ba82..f9384e6d828 100644 --- a/tests/test_diffusion_model_unet_maisi.py +++ b/tests/test_diffusion_model_unet_maisi.py @@ -17,14 +17,11 @@ import torch from parameterized import parameterized +from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi from monai.networks import eval_mode from monai.utils import optional_import _, has_einops = optional_import("einops") -_, has_generative = optional_import("generative") - -if has_generative: - from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi UNCOND_CASES_2D = [ [ @@ -291,7 +288,6 @@ ] -@skipUnless(has_generative, "monai-generative required") class TestDiffusionModelUNetMaisi2D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D) @@ -510,7 +506,6 @@ def test_shape_with_additional_inputs(self, input_param): self.assertEqual(result.shape, (1, 1, 16, 16)) -@skipUnless(has_generative, "monai-generative required") class TestDiffusionModelUNetMaisi3D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_3D) diff --git a/tests/test_handler_ignite_metric.py b/tests/test_handler_ignite_metric.py index 28e0b69621f..3e42bda35dc 100644 --- a/tests/test_handler_ignite_metric.py +++ b/tests/test_handler_ignite_metric.py @@ -16,7 +16,7 @@ import torch from parameterized import 
parameterized -from monai.handlers import IgniteMetric, IgniteMetricHandler, from_engine +from monai.handlers import IgniteMetricHandler, from_engine from monai.losses import DiceLoss from monai.metrics import LossMetric from tests.utils import SkipIfNoModule, assert_allclose, optional_import @@ -172,7 +172,7 @@ def _val_func(engine, batch): @parameterized.expand(TEST_CASES[0:2]) def test_old_ignite_metric(self, input_param, input_data, expected_val): loss_fn = DiceLoss(**input_param) - ignite_metric = IgniteMetric(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"])) + ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"])) def _val_func(engine, batch): pass diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index f31a07eba49..60b60197039 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -448,7 +448,7 @@ def test_shape(self): def test_astype(self): t = MetaTensor([1.0], affine=torch.tensor(1), meta={"fname": "filename"}) - for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.compat.long, np.uint16): + for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.uint16): self.assertIsInstance(t.astype(np_types), np.ndarray) for pt_types in ("torch.float", torch.float, "torch.float64"): self.assertIsInstance(t.astype(pt_types), torch.Tensor) diff --git a/tests/test_morphological_ops.py b/tests/test_morphological_ops.py index 6f294157594..422e8c4b9df 100644 --- a/tests/test_morphological_ops.py +++ b/tests/test_morphological_ops.py @@ -16,7 +16,7 @@ import torch from parameterized import parameterized -from monai.apps.generation.maisi.utils.morphological_ops import dilate, erode, get_morphological_filter_result_t +from monai.transforms.utils_morphological_ops import dilate, erode, get_morphological_filter_result_t from tests.utils import TEST_NDARRAYS, assert_allclose TESTS_SHAPE = [] diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index 51ec275cf41..704bbdb9b13 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -47,6 +47,7 @@ TEST_CASES = [ [{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1442], + [{"classes": 3, "dim": 2}, {"inputs": inputs.repeat(4, 1, 1, 1), "targets": targets.repeat(4, 1, 1)}, 1.1442], [{"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, {"inputs": inputs, "targets": targets}, 1.1433], [{"classes": 3, "dim": 2, "kernel_ops": "gaussian", "sigma": 0.5}, {"inputs": inputs, "targets": targets}, 1.1469], [{"classes": 3, "dim": 2, "distance_type": "l2"}, {"inputs": inputs, "targets": targets}, 1.1269], diff --git a/tests/test_nifti_endianness.py b/tests/test_nifti_endianness.py index 4475d8aaabf..f8531dc08f4 100644 --- a/tests/test_nifti_endianness.py +++ b/tests/test_nifti_endianness.py @@ -82,7 +82,7 @@ def test_switch(self): # verify data types after = switch_endianness(before) np.testing.assert_allclose(after.astype(float), expected_float) - before = np.array(["1.12", "-9.2", "42"], dtype=np.string_) + before = np.array(["1.12", "-9.2", "42"], dtype=np.bytes_) after = switch_endianness(before) np.testing.assert_array_equal(before, after) diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index d0591450334..71ac767966a 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -43,7 +43,7 @@ "patch_size": (patch_size,) * nd, "hidden_size": hidden_size, "num_heads": num_heads, - "pos_embed": proj_type, + "proj_type": 
proj_type, "pos_embed_type": pos_embed_type, "dropout_rate": dropout_rate, }, @@ -127,7 +127,7 @@ def test_ill_arg(self): patch_size=(16, 16, 16), hidden_size=128, num_heads=12, - pos_embed="conv", + proj_type="conv", pos_embed_type="sincos", dropout_rate=5.0, ) @@ -139,7 +139,7 @@ def test_ill_arg(self): patch_size=(64, 64, 64), hidden_size=512, num_heads=8, - pos_embed="perceptron", + proj_type="perceptron", pos_embed_type="sincos", dropout_rate=0.3, ) @@ -151,7 +151,7 @@ def test_ill_arg(self): patch_size=(8, 8, 8), hidden_size=512, num_heads=14, - pos_embed="conv", + proj_type="conv", dropout_rate=0.3, ) @@ -162,7 +162,7 @@ def test_ill_arg(self): patch_size=(4, 4, 4), hidden_size=768, num_heads=8, - pos_embed="perceptron", + proj_type="perceptron", dropout_rate=0.3, ) with self.assertRaises(ValueError): @@ -183,7 +183,7 @@ def test_ill_arg(self): patch_size=(16, 16, 16), hidden_size=768, num_heads=12, - pos_embed="perc", + proj_type="perc", dropout_rate=0.3, ) diff --git a/tests/test_point_based_window_inferer.py b/tests/test_point_based_window_inferer.py new file mode 100644 index 00000000000..1b293288c4d --- /dev/null +++ b/tests/test_point_based_window_inferer.py @@ -0,0 +1,77 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.apps.vista3d.inferer import point_based_window_inferer +from monai.networks import eval_mode +from monai.networks.nets.vista3d import vista3d132 +from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick + +device = "cuda" if torch.cuda.is_available() else "cpu" + +_, has_tqdm = optional_import("tqdm") + +TEST_CASES = [ + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + }, + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + "class_vector": torch.tensor([1], device=device), + }, + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + "class_vector": torch.tensor([1], device=device), + "point_start": 1, + }, + ], +] + + +@SkipIfBeforePyTorchVersion((1, 11)) +@skip_if_quick +class TestPointBasedWindowInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_vista3d(self, vista3d_params, inputs_shape, inferer_params): + vista3d = vista3d132(**vista3d_params).to(device) + with eval_mode(vista3d): + inferer_params["predictor"] = vista3d + inferer_params["inputs"] = torch.randn(*inputs_shape).to(device) + stitched_output = point_based_window_inferer(**inferer_params) + self.assertEqual(stitched_output.shape, inputs_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_segresnet_ds.py b/tests/test_segresnet_ds.py index 5372fcc8ae4..eab7bac9a0d 100644 --- a/tests/test_segresnet_ds.py +++ b/tests/test_segresnet_ds.py @@ -17,7 +17,7 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets import SegResNetDS +from monai.networks.nets import SegResNetDS, SegResNetDS2 from tests.utils import SkipIfBeforePyTorchVersion, test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -71,7 +71,7 @@ ] -class TestResNetDS(unittest.TestCase): +class TestSegResNetDS(unittest.TestCase): @parameterized.expand(TEST_CASE_SEGRESNET_DS) def test_shape(self, input_param, input_shape, expected_shape): @@ -80,47 +80,71 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + @parameterized.expand(TEST_CASE_SEGRESNET_DS) + def test_shape_ds2(self, input_param, input_shape, expected_shape): + net = SegResNetDS2(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device), with_label=False) + self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + self.assertTrue(result[1] == []) + + result = net(torch.randn(input_shape).to(device), with_point=False) + self.assertEqual(result[1].shape, expected_shape, msg=str(input_param)) + self.assertTrue(result[0] == []) + @parameterized.expand(TEST_CASE_SEGRESNET_DS2) def test_shape2(self, input_param, input_shape, expected_shape): dsdepth = input_param.get("dsdepth", 1) 
- net = SegResNetDS(**input_param).to(device) - - net.train() - result = net(torch.randn(input_shape).to(device)) - if dsdepth > 1: - assert isinstance(result, list) - self.assertEqual(dsdepth, len(result)) - for i in range(dsdepth): - self.assertEqual( - result[i].shape, - expected_shape[:2] + tuple(e // (2**i) for e in expected_shape[2:]), - msg=str(input_param), - ) - else: - assert isinstance(result, torch.Tensor) - self.assertEqual(result.shape, expected_shape, msg=str(input_param)) - - net.eval() - result = net(torch.randn(input_shape).to(device)) - assert isinstance(result, torch.Tensor) - self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + for net in [SegResNetDS, SegResNetDS2]: + net = net(**input_param).to(device) + net.train() + if isinstance(net, SegResNetDS2): + result = net(torch.randn(input_shape).to(device), with_label=False)[0] + else: + result = net(torch.randn(input_shape).to(device)) + if dsdepth > 1: + assert isinstance(result, list) + self.assertEqual(dsdepth, len(result)) + for i in range(dsdepth): + self.assertEqual( + result[i].shape, + expected_shape[:2] + tuple(e // (2**i) for e in expected_shape[2:]), + msg=str(input_param), + ) + else: + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + if not isinstance(net, SegResNetDS2): + # eval mode of SegResNetDS2 has same output as training mode + # so only test eval mode for SegResNetDS + net.eval() + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) @parameterized.expand(TEST_CASE_SEGRESNET_DS3) def test_shape3(self, input_param, input_shape, expected_shapes): dsdepth = input_param.get("dsdepth", 1) - net = SegResNetDS(**input_param).to(device) - - net.train() - result = net(torch.randn(input_shape).to(device)) - assert isinstance(result, list) - self.assertEqual(dsdepth, len(result)) - for i in range(dsdepth): - self.assertEqual(result[i].shape, expected_shapes[i], msg=str(input_param)) + for net in [SegResNetDS, SegResNetDS2]: + net = net(**input_param).to(device) + net.train() + if isinstance(net, SegResNetDS2): + result = net(torch.randn(input_shape).to(device), with_point=False)[1] + else: + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, list) + self.assertEqual(dsdepth, len(result)) + for i in range(dsdepth): + self.assertEqual(result[i].shape, expected_shapes[i], msg=str(input_param)) def test_ill_arg(self): with self.assertRaises(ValueError): SegResNetDS(spatial_dims=4) + with self.assertRaises(ValueError): + SegResNetDS2(spatial_dims=4) + @SkipIfBeforePyTorchVersion((1, 10)) def test_script(self): input_param, input_shape, _ = TEST_CASE_SEGRESNET_DS[0] diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 3e98f4c5c46..88919fd8b19 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -22,7 +22,7 @@ from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save einops, has_einops = optional_import("einops") @@ -32,20 +32,23 @@ for num_heads in [4, 6, 8, 12]: for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: for input_size in [(16, 32), (8, 8, 8)]: - for flash_attn in [True, False]: - test_case = [ - { - 
"hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding if not flash_attn else None, - "input_size": input_size, - "use_flash_attention": flash_attn, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for include_fc in [True, False]: + for use_combined_linear in [True, False]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + "include_fc": include_fc, + "use_combined_linear": use_combined_linear, + "use_flash_attention": True if rel_pos_embedding is None else False, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): @@ -175,6 +178,39 @@ def count_sablock_params(*args, **kwargs): nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) self.assertEqual(nparams_default, nparams_default_more_heads) + @parameterized.expand([[True, False], [True, True], [False, True], [False, False]]) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_script(self, include_fc, use_combined_linear): + input_param = { + "hidden_size": 360, + "num_heads": 4, + "dropout_rate": 0.0, + "rel_pos_embedding": None, + "input_size": (16, 32), + "include_fc": include_fc, + "use_combined_linear": use_combined_linear, + } + net = SABlock(**input_param) + input_shape = (2, 512, 360) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_flash_attention(self): + for causal in [True, False]: + input_param = {"hidden_size": 360, "num_heads": 4, "input_size": (16, 32), "causal": causal} + device = "cuda:0" if torch.cuda.is_available() else "cpu" + block_w_flash_attention = SABlock(**input_param, use_flash_attention=True).to(device) + block_wo_flash_attention = SABlock(**input_param, use_flash_attention=False).to(device) + block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict()) + test_data = torch.randn(2, 512, 360).to(device) + + out_1 = block_w_flash_attention(test_data) + out_2 = block_wo_flash_attention(test_data) + assert_allclose(out_1, out_2, atol=1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_signal_fillempty.py b/tests/test_signal_fillempty.py index a3ee623cc5c..2be4bd8600c 100644 --- a/tests/test_signal_fillempty.py +++ b/tests/test_signal_fillempty.py @@ -30,7 +30,7 @@ class TestSignalFillEmptyNumpy(unittest.TestCase): def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty) sig = np.load(TEST_SIGNAL) - sig[:, 123] = np.NAN + sig[:, 123] = np.nan fillempty = SignalFillEmpty(replacement=0.0) fillemptysignal = fillempty(sig) self.assertTrue(not np.isnan(fillemptysignal).any()) @@ -42,7 +42,7 @@ class TestSignalFillEmptyTorch(unittest.TestCase): def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty) sig = convert_to_tensor(np.load(TEST_SIGNAL)) - sig[:, 123] = convert_to_tensor(np.NAN) + sig[:, 123] = convert_to_tensor(np.nan) fillempty = SignalFillEmpty(replacement=0.0) fillemptysignal = fillempty(sig) self.assertTrue(not torch.isnan(fillemptysignal).any()) diff --git a/tests/test_signal_fillemptyd.py 
b/tests/test_signal_fillemptyd.py index ee8c571ef8b..77102794959 100644 --- a/tests/test_signal_fillemptyd.py +++ b/tests/test_signal_fillemptyd.py @@ -30,7 +30,7 @@ class TestSignalFillEmptyNumpy(unittest.TestCase): def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd) sig = np.load(TEST_SIGNAL) - sig[:, 123] = np.NAN + sig[:, 123] = np.nan data = {} data["signal"] = sig fillempty = SignalFillEmptyd(keys=("signal",), replacement=0.0) @@ -46,7 +46,7 @@ class TestSignalFillEmptyTorch(unittest.TestCase): def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd) sig = convert_to_tensor(np.load(TEST_SIGNAL)) - sig[:, 123] = convert_to_tensor(np.NAN) + sig[:, 123] = convert_to_tensor(np.nan) data = {} data["signal"] = sig fillempty = SignalFillEmptyd(keys=("signal",), replacement=0.0) diff --git a/tests/test_unetr.py b/tests/test_unetr.py index 46018d2bc0b..8c5ecb32e16 100644 --- a/tests/test_unetr.py +++ b/tests/test_unetr.py @@ -30,7 +30,7 @@ for num_heads in [8]: for mlp_dim in [3072]: for norm_name in ["instance"]: - for pos_embed in ["perceptron"]: + for proj_type in ["perceptron"]: for nd in (2, 3): test_case = [ { @@ -42,7 +42,7 @@ "norm_name": norm_name, "mlp_dim": mlp_dim, "num_heads": num_heads, - "pos_embed": pos_embed, + "proj_type": proj_type, "dropout_rate": dropout_rate, "conv_block": True, "res_block": False, @@ -75,7 +75,7 @@ def test_ill_arg(self): hidden_size=128, mlp_dim=3072, num_heads=12, - pos_embed="conv", + proj_type="conv", norm_name="instance", dropout_rate=5.0, ) @@ -89,7 +89,7 @@ def test_ill_arg(self): hidden_size=512, mlp_dim=3072, num_heads=12, - pos_embed="conv", + proj_type="conv", norm_name="instance", dropout_rate=0.5, ) @@ -103,7 +103,7 @@ def test_ill_arg(self): hidden_size=512, mlp_dim=3072, num_heads=14, - pos_embed="conv", + proj_type="conv", norm_name="batch", dropout_rate=0.4, ) @@ -117,13 +117,13 @@ def test_ill_arg(self): hidden_size=768, mlp_dim=3072, num_heads=12, - pos_embed="perc", + proj_type="perc", norm_name="instance", dropout_rate=0.2, ) @parameterized.expand(TEST_CASE_UNETR) - @SkipIfBeforePyTorchVersion((1, 9)) + @SkipIfBeforePyTorchVersion((2, 0)) def test_script(self, input_param, input_shape, _): net = UNETR(**(input_param)) net.eval() diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 6e655289e47..90c0401e462 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -17,7 +17,7 @@ import torch from parameterized import parameterized -from monai.transforms.utils_pytorch_numpy_unification import mode, percentile +from monai.transforms.utils_pytorch_numpy_unification import max, min, mode, percentile from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_quick @@ -27,6 +27,13 @@ TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False]) TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True]) +TEST_MIN_MAX = [] +for p in TEST_NDARRAYS: + TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, min, p(1)]) + TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, min, p([3.1, 3])]) + TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, max, p(5)]) + TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, max, p([5.1, 5])]) + class 
TestPytorchNumpyUnification(unittest.TestCase): @@ -74,6 +81,11 @@ def test_mode(self, array, expected, to_long): res = mode(array, to_long=to_long) assert_allclose(res, expected) + @parameterized.expand(TEST_MIN_MAX) + def test_min_max(self, array, input_params, func, expected): + res = func(array, **input_params) + assert_allclose(res, expected, type_test=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_vista3d.py b/tests/test_vista3d.py new file mode 100644 index 00000000000..d3b4e0c10e0 --- /dev/null +++ b/tests/test_vista3d.py @@ -0,0 +1,85 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import VISTA3D, SegResNetDS2 +from monai.networks.nets.vista3d import ClassMappingClassify, PointMappingSAM +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASES = [ + [{"encoder_embed_dim": 48, "in_channels": 1}, {}, (1, 1, 64, 64, 64), (1, 1, 64, 64, 64)], + [{"encoder_embed_dim": 48, "in_channels": 2}, {}, (1, 2, 64, 64, 64), (1, 1, 64, 64, 64)], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + {"class_vector": torch.tensor([1, 2, 3], device=device)}, + (1, 1, 64, 64, 64), + (3, 1, 64, 64, 64), + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + { + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + }, + (1, 1, 64, 64, 64), + (1, 1, 64, 64, 64), + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + { + "class_vector": torch.tensor([1, 2], device=device), + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0], [1, 0]], device=device), + }, + (1, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], +] + + +@SkipIfBeforePyTorchVersion((1, 11)) +@skip_if_quick +class TestVista3d(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_vista3d_shape(self, args, input_params, input_shape, expected_shape): + segresnet = SegResNetDS2( + in_channels=args["in_channels"], + blocks_down=(1, 2, 2, 4, 4), + norm="instance", + out_channels=args["encoder_embed_dim"], + init_filters=args["encoder_embed_dim"], + dsdepth=1, + ) + point_head = PointMappingSAM(feature_size=args["encoder_embed_dim"], n_classes=512, last_supported=132) + class_head = ClassMappingClassify(n_classes=512, feature_size=args["encoder_embed_dim"], use_mlp=True) + net = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head).to(device) + with eval_mode(net): + result = net.forward( + torch.randn(input_shape).to(device), + point_coords=input_params.get("point_coords", None), + point_labels=input_params.get("point_labels", None), + class_vector=input_params.get("class_vector", None), + ) + self.assertEqual(result.shape, expected_shape) + + 
+if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vista3d_sampler.py b/tests/test_vista3d_sampler.py new file mode 100644 index 00000000000..6945d250d20 --- /dev/null +++ b/tests/test_vista3d_sampler.py @@ -0,0 +1,100 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.apps.vista3d.sampler import sample_prompt_pairs + +label = torch.zeros([1, 1, 64, 64, 64]) +label[:, :, :10, :10, :10] = 1 +label[:, :, 20:30, 20:30, 20:30] = 2 +label[:, :, 30:40, 30:40, 30:40] = 3 +label1 = torch.zeros([1, 1, 64, 64, 64]) + +TEST_VISTA_SAMPLE_PROMPT = [ + [ + { + "labels": label, + "label_set": [0, 1, 2, 3, 4], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 0, + }, + [4, 4, 4, 4], + ], + [ + { + "labels": label, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [2, None, None, 2], + ], + [ + { + "labels": label, + "label_set": [0, 1, 2, 3, 4], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 1, + "drop_point_prob": 0, + }, + [None, 3, 3, 3], + ], + [ + { + "labels": label1, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [1, None, None, 1], + ], + [ + { + "labels": label1, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 0, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [None, None, None, None], + ], +] + + +class TestGeneratePrompt(unittest.TestCase): + @parameterized.expand(TEST_VISTA_SAMPLE_PROMPT) + def test_result(self, input_data, expected): + output = sample_prompt_pairs(**input_data) + result = [i.shape[0] if i is not None else None for i in output] + self.assertEqual(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py new file mode 100644 index 00000000000..9d61fe2fc23 --- /dev/null +++ b/tests/test_vista3d_transforms.py @@ -0,0 +1,94 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest +from unittest.case import skipUnless + +import torch +from parameterized import parameterized + +from monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd +from monai.utils import min_version +from monai.utils.module import optional_import + +measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +TEST_VISTA_PRETRANSFORM = [ + [ + {"label_prompt": [1], "points": [[0, 0, 0]], "point_labels": [1]}, + {"label_prompt": [1], "points": [[0, 0, 0]], "point_labels": [3]}, + ], + [ + {"label_prompt": [2], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [2], "points": [[0, 0, 0]], "point_labels": [2]}, + ], + [ + {"label_prompt": [3], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [4, 5], "points": [[0, 0, 0]], "point_labels": [0]}, + ], + [ + {"label_prompt": [6], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [7, 8], "points": [[0, 0, 0]], "point_labels": [0]}, + ], +] + + +pred1 = torch.zeros([2, 64, 64, 64]) +pred1[0, :10, :10, :10] = 1 +pred1[1, 20:30, 20:30, 20:30] = 1 +output1 = torch.zeros([1, 64, 64, 64]) +output1[:, :10, :10, :10] = 2 +output1[:, 20:30, 20:30, 20:30] = 3 + +# -1 is needed since pred should be before sigmoid. +pred2 = torch.zeros([1, 64, 64, 64]) - 1 +pred2[:, :10, :10, :10] = 1 +pred2[:, 20:30, 20:30, 20:30] = 1 +output2 = torch.zeros([1, 64, 64, 64]) +output2[:, 20:30, 20:30, 20:30] = 1 + +TEST_VISTA_POSTTRANSFORM = [ + [{"pred": pred1.to(device), "label_prompt": torch.tensor([2, 3]).to(device)}, output1.to(device)], + [ + { + "pred": pred2.to(device), + "points": torch.tensor([[25, 25, 25]]).to(device), + "point_labels": torch.tensor([1]).to(device), + }, + output2.to(device), + ], +] + + +class TestVistaPreTransformd(unittest.TestCase): + @parameterized.expand(TEST_VISTA_PRETRANSFORM) + def test_result(self, input_data, expected): + transform = VistaPreTransformd(keys="image", subclass={"3": [4, 5], "6": [7, 8]}, special_index=[1, 2]) + result = transform(input_data) + self.assertEqual(result, expected) + + +@skipUnless(has_measure, "skimage.measure required") +class TestVistaPostTransformd(unittest.TestCase): + @parameterized.expand(TEST_VISTA_POSTTRANSFORM) + def test_result(self, input_data, expected): + transform = VistaPostTransformd(keys="pred") + result = transform(input_data) + self.assertEqual((result["pred"] == expected).all(), True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py new file mode 100644 index 00000000000..5a0caedd61d --- /dev/null +++ b/tests/test_vista3d_utils.py @@ -0,0 +1,162 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest +from unittest.case import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.utils import convert_points_to_disc, keep_merge_components_with_points, sample_points_from_label +from monai.utils import min_version +from monai.utils.module import optional_import +from tests.utils import skip_if_no_cuda, skip_if_quick + +cp, has_cp = optional_import("cupy") +cucim_skimage, has_cucim = optional_import("cucim.skimage") +measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +TESTS_SAMPLE_POINTS_FROM_LABEL = [] +for use_center in [True, False]: + labels = torch.zeros(1, 1, 32, 32, 32) + labels[0, 0, 5:10, 5:10, 5:10] = 1 + labels[0, 0, 10:15, 10:15, 10:15] = 3 + labels[0, 0, 20:25, 20:25, 20:25] = 5 + TESTS_SAMPLE_POINTS_FROM_LABEL.append( + [{"labels": labels, "label_set": (1, 3, 5), "use_center": use_center}, (3, 1, 3), (3, 1)] + ) + +TEST_CONVERT_POINTS_TO_DISC = [] +for radius in [1, 2]: + for disc in [True, False]: + image_size = (32, 32, 32) + point = torch.randn(3, 1, 3) + point_label = torch.randint(0, 4, (3, 1)) + expected_shape = (point.shape[0], 2, *image_size) + TEST_CONVERT_POINTS_TO_DISC.append( + [ + {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, + expected_shape, + ] + ) + image_size = (16, 32, 64) + point = torch.tensor([[[8, 16, 42], [2, 8, 21]]]) + point_label = torch.tensor([[1, 0]]) + expected_shape = (point.shape[0], 2, *image_size) + TEST_CONVERT_POINTS_TO_DISC.append( + [ + {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, + expected_shape, + ] + ) + +TEST_CONVERT_POINTS_TO_DISC_VALUE = [] +image_size = (16, 32, 64) +point = torch.tensor([[[8, 16, 42], [2, 8, 21]]]) +point_label = torch.tensor([[1, 0]]) +expected_shape = (point.shape[0], 2, *image_size) +for radius in [5, 10]: + for disc in [True, False]: + TEST_CONVERT_POINTS_TO_DISC_VALUE.append( + [ + {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, + [point, point_label], + ] + ) + + +TEST_LCC_MASK_POINT_TORCH = [] +for bs in [1, 2]: + for num_points in [1, 3]: + shape = (bs, 1, 128, 32, 32) + TEST_LCC_MASK_POINT_TORCH.append( + [ + { + "img_pos": torch.randint(0, 2, shape, dtype=torch.bool), + "img_neg": torch.randint(0, 2, shape, dtype=torch.bool), + "point_coords": torch.randint(0, 10, (bs, num_points, 3)), + "point_labels": torch.randint(0, 4, (bs, num_points)), + }, + shape, + ] + ) + +TEST_LCC_MASK_POINT_NP = [] +for bs in [1, 2]: + for num_points in [1, 3]: + shape = (bs, 1, 32, 32, 64) + TEST_LCC_MASK_POINT_NP.append( + [ + { + "img_pos": np.random.randint(0, 2, shape, dtype=bool), + "img_neg": np.random.randint(0, 2, shape, dtype=bool), + "point_coords": np.random.randint(0, 5, (bs, num_points, 3)), + "point_labels": np.random.randint(0, 4, (bs, num_points)), + }, + shape, + ] + ) + + +@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") +class TestSamplePointsFromLabel(unittest.TestCase): + + @parameterized.expand(TESTS_SAMPLE_POINTS_FROM_LABEL) + def test_shape(self, input_data, expected_point_shape, expected_point_label_shape): + point, point_label = sample_points_from_label(**input_data) + self.assertEqual(point.shape, expected_point_shape) + self.assertEqual(point_label.shape, 
expected_point_label_shape) + + +class TestConvertPointsToDisc(unittest.TestCase): + + @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC) + def test_shape(self, input_data, expected_shape): + result = convert_points_to_disc(**input_data) + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC_VALUE) + def test_value(self, input_data, points): + result = convert_points_to_disc(**input_data) + point, point_label = points + for i in range(point.shape[0]): + for j in range(point.shape[1]): + self.assertEqual(result[i, point_label[i, j], point[i, j][0], point[i, j][1], point[i, j][2]], True) + + +@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") +class TestKeepMergeComponentsWithPoints(unittest.TestCase): + + @skip_if_quick + @skip_if_no_cuda + @skipUnless(has_cp and cucim_skimage, "cupy and cucim.skimage required") + @parameterized.expand(TEST_LCC_MASK_POINT_TORCH) + def test_cp_shape(self, input_data, shape): + for key in input_data: + input_data[key] = input_data[key].to(device) + mask = keep_merge_components_with_points(**input_data) + self.assertEqual(mask.shape, shape) + + @skipUnless(has_measure, "skimage required") + @parameterized.expand(TEST_LCC_MASK_POINT_NP) + def test_np_shape(self, input_data, shape): + mask = keep_merge_components_with_points(**input_data) + self.assertEqual(mask.shape, shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vit.py b/tests/test_vit.py index d638c0116af..a3ffd0b2ef7 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -106,7 +106,7 @@ def test_ill_arg( ) @parameterized.expand(TEST_CASE_Vit[:1]) - @SkipIfBeforePyTorchVersion((1, 9)) + @SkipIfBeforePyTorchVersion((2, 0)) def test_script(self, input_param, input_shape, _): net = ViT(**(input_param)) net.eval() diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py index c68c583a0e7..9a503948d04 100644 --- a/tests/test_vitautoenc.py +++ b/tests/test_vitautoenc.py @@ -23,7 +23,7 @@ for in_channels in [1, 4]: for img_size in [64, 96, 128]: for patch_size in [16]: - for pos_embed in ["conv", "perceptron"]: + for proj_type in ["conv", "perceptron"]: for nd in [2, 3]: test_case = [ { @@ -34,7 +34,7 @@ "mlp_dim": 3072, "num_layers": 4, "num_heads": 12, - "pos_embed": pos_embed, + "proj_type": proj_type, "dropout_rate": 0.6, "spatial_dims": nd, }, @@ -54,7 +54,7 @@ "mlp_dim": 3072, "num_layers": 4, "num_heads": 12, - "pos_embed": "conv", + "proj_type": "conv", "dropout_rate": 0.6, "spatial_dims": 3, }, @@ -93,7 +93,7 @@ def test_shape(self, input_param, input_shape, expected_shape): ] ) def test_ill_arg( - self, in_channels, img_size, patch_size, hidden_size, mlp_dim, num_layers, num_heads, pos_embed, dropout_rate + self, in_channels, img_size, patch_size, hidden_size, mlp_dim, num_layers, num_heads, proj_type, dropout_rate ): with self.assertRaises(ValueError): ViTAutoEnc( @@ -104,7 +104,7 @@ def test_ill_arg( mlp_dim=mlp_dim, num_layers=num_layers, num_heads=num_heads, - pos_embed=pos_embed, + proj_type=proj_type, dropout_rate=dropout_rate, )