From baabe0b3a3258d52e7316fbea70a8132d90e1b33 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 8 Aug 2024 14:32:09 -0400 Subject: [PATCH 1/6] Add mednext implementation Signed-off-by: Suraj Pai --- monai/networks/blocks/__init__.py | 1 + monai/networks/blocks/mednext_block.py | 233 +++++++++++++++++++++++ monai/networks/nets/__init__.py | 1 + monai/networks/nets/mednext.py | 244 +++++++++++++++++++++++++ tests/test_mednext.py | 98 ++++++++++ 5 files changed, 577 insertions(+) create mode 100644 monai/networks/blocks/mednext_block.py create mode 100644 monai/networks/nets/mednext.py create mode 100644 tests/test_mednext.py diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 47abc4a1c4..a535c0ab26 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -26,6 +26,7 @@ from .fcn import FCN, GCN, MCFCN, Refine from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7 from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock +from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtUpBlock, OutBlock from .mlp import MLPBlock from .patchembedding import PatchEmbed, PatchEmbeddingBlock from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock diff --git a/monai/networks/blocks/mednext_block.py b/monai/networks/blocks/mednext_block.py new file mode 100644 index 0000000000..42617d2e56 --- /dev/null +++ b/monai/networks/blocks/mednext_block.py @@ -0,0 +1,233 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Portions of this code are derived from the original repository at: +# https://github.com/MIC-DKFZ/MedNeXt +# and are used under the terms of the Apache License, Version 2.0. + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MedNeXtBlock(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + exp_r: int = 4, + kernel_size: int = 7, + do_res: int = True, + norm_type: str = "group", + n_groups: int or None = None, + dim="3d", + grn=False, + ): + + super().__init__() + + self.do_res = do_res + + assert dim in ["2d", "3d"] + self.dim = dim + if self.dim == "2d": + conv = nn.Conv2d + else: + conv = nn.Conv3d + + # First convolution layer with DepthWise Convolutions + self.conv1 = conv( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=in_channels if n_groups is None else n_groups, + ) + + # Normalization Layer. GroupNorm is used by default. 
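+        # Note: with num_groups == num_channels, GroupNorm normalizes each
+        # channel independently (equivalent to InstanceNorm), matching the
+        # reference MedNeXt implementation this block is derived from.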
+ if norm_type == "group": + self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) + elif norm_type == "layer": + self.norm = LayerNorm(normalized_shape=in_channels, data_format="channels_first") + + # Second convolution (Expansion) layer with Conv3D 1x1x1 + self.conv2 = conv(in_channels=in_channels, out_channels=exp_r * in_channels, kernel_size=1, stride=1, padding=0) + + # GeLU activations + self.act = nn.GELU() + + # Third convolution (Compression) layer with Conv3D 1x1x1 + self.conv3 = conv( + in_channels=exp_r * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 + ) + + self.grn = grn + if self.grn: + if dim == "2d": + self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True) + self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True) + else: + self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True) + self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True) + + def forward(self, x, dummy_tensor=None): + + x1 = x + x1 = self.conv1(x1) + x1 = self.act(self.conv2(self.norm(x1))) + + if self.grn: + # gamma, beta: learnable affine transform parameters + # X: input of shape (N,C,H,W,D) + if self.dim == "2d": + gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True) + else: + gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True) + nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6) + x1 = self.grn_gamma * (x1 * nx) + self.grn_beta + x1 + x1 = self.conv3(x1) + if self.do_res: + x1 = x + x1 + return x1 + + +class MedNeXtDownBlock(MedNeXtBlock): + + def __init__( + self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False + ): + + super().__init__( + in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn + ) + + if dim == "2d": + conv = nn.Conv2d + else: + conv = nn.Conv3d + self.resample_do_res = do_res + if do_res: + self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) + + self.conv1 = conv( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=2, + padding=kernel_size // 2, + groups=in_channels, + ) + + def forward(self, x, dummy_tensor=None): + + x1 = super().forward(x) + + if self.resample_do_res: + res = self.res_conv(x) + x1 = x1 + res + + return x1 + + +class MedNeXtUpBlock(MedNeXtBlock): + + def __init__( + self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False + ): + super().__init__( + in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn + ) + + self.resample_do_res = do_res + + self.dim = dim + if dim == "2d": + conv = nn.ConvTranspose2d + else: + conv = nn.ConvTranspose3d + if do_res: + self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) + + self.conv1 = conv( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=2, + padding=kernel_size // 2, + groups=in_channels, + ) + + def forward(self, x, dummy_tensor=None): + + x1 = super().forward(x) + # Asymmetry but necessary to match shape + + if self.dim == "2d": + x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0)) + else: + x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0, 1, 0)) + + if self.resample_do_res: + res = self.res_conv(x) + if self.dim == "2d": + res = torch.nn.functional.pad(res, (1, 0, 1, 0)) 
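+            # (a stride-2 transposed conv with padding k//2 produces 2*s - 1
+            # samples per spatial dim, so one leading pad element restores 2*s)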
+ else: + res = torch.nn.functional.pad(res, (1, 0, 1, 0, 1, 0)) + x1 = x1 + res + + return x1 + + +class OutBlock(nn.Module): + + def __init__(self, in_channels, n_classes, dim): + super().__init__() + + if dim == "2d": + conv = nn.ConvTranspose2d + else: + conv = nn.ConvTranspose3d + self.conv_out = conv(in_channels, n_classes, kernel_size=1) + + def forward(self, x, dummy_tensor=None): + return self.conv_out(x) + + +class LayerNorm(nn.Module): + """LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) # beta + self.bias = nn.Parameter(torch.zeros(normalized_shape)) # gamma + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x, dummy_tensor=False): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None] + return x diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index c777fe6442..f62fe432fa 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -53,6 +53,7 @@ from .generator import Generator from .highresnet import HighResBlock, HighResNet from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet +from .mednext import MedNeXt from .milmodel import MILModel from .netadapter import NetAdapter from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py new file mode 100644 index 0000000000..293db7d443 --- /dev/null +++ b/monai/networks/nets/mednext.py @@ -0,0 +1,244 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Portions of this code are derived from the original repository at: +# https://github.com/MIC-DKFZ/MedNeXt +# and are used under the terms of the Apache License, Version 2.0. + +from __future__ import annotations + +import torch +import torch.nn as nn + +from monai.networks.blocks import MedNeXtBlock, MedNeXtDownBlock, MedNeXtUpBlock, OutBlock + + +class MedNeXt(nn.Module): + """ + MedNeXt model class from paper: https://arxiv.org/pdf/2303.09975 + + Args: + spatial_dims: spatial dimension of the input data. Defaults to 3. + init_filters: number of output channels for initial convolution layer. Defaults to 32. 
+ in_channels: number of input channels for the network. Defaults to 1. + out_channels: number of output channels for the network. Defaults to 2. + enc_exp_r: expansion ratio for encoder blocks. Defaults to 2. + dec_exp_r: expansion ratio for decoder blocks. Defaults to 2. + bottlenec_exp_r: expansion ratio for bottleneck blocks. Defaults to 2. + kernel_size: kernel size for convolutions. Defaults to 7. + deep_supervision: whether to use deep supervision. Defaults to False. + do_res: whether to use residual connections. Defaults to False. + do_res_up_down: whether to use residual connections in up and down blocks. Defaults to False. + blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2]. + blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2. + blocks_up: number of blocks in each decoder stage. Defaults to [2, 2, 2, 2]. + norm_type: type of normalization layer. Defaults to 'group'. + grn: whether to use Global Response Normalization (GRN). Defaults to False. + """ + + def __init__( + self, + spatial_dims: int = 3, + init_filters: int = 32, + in_channels: int = 1, + out_channels: int = 2, + enc_exp_r: int = 2, + dec_exp_r: int = 2, + bottlenec_exp_r: int = 2, + kernel_size: int = 7, + deep_supervision: bool = False, + do_res: bool = False, + do_res_up_down: bool = False, + blocks_down: list = [2, 2, 2, 2], + blocks_bottleneck: int = 2, + blocks_up: list = [2, 2, 2, 2], + norm_type: str = "group", + grn: bool = False, + ): + """ + Initialize the MedNeXt model. + + This method sets up the architecture of the model, including: + - Stem convolution + - Encoder stages and downsampling blocks + - Bottleneck blocks + - Decoder stages and upsampling blocks + - Output blocks for deep supervision (if enabled) + """ + super().__init__() + + self.do_ds = deep_supervision + assert spatial_dims in [2, 3], "`spatial_dims` can only be 2 or 3." 
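+        # Each encoder stage downsamples by 2, so input spatial sizes should be
+        # divisible by 2 ** len(blocks_down) (16 for the default configuration).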
+ spatial_dims_str = f"{spatial_dims}d" + enc_kernel_size = dec_kernel_size = kernel_size + + if isinstance(enc_exp_r, int): + enc_exp_r = [enc_exp_r] * len(blocks_down) + + if isinstance(dec_exp_r, int): + dec_exp_r = [dec_exp_r] * len(blocks_up) + + conv = nn.Conv2d if spatial_dims_str == "2d" else nn.Conv3d + + self.stem = conv(in_channels, init_filters, kernel_size=1) + + enc_stages = [] + down_blocks = [] + + for i, num_blocks in enumerate(blocks_down): + enc_stages.append( + nn.Sequential( + *[ + MedNeXtBlock( + in_channels=init_filters * (2**i), + out_channels=init_filters * (2**i), + exp_r=enc_exp_r[i], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=spatial_dims_str, + grn=grn, + ) + for _ in range(num_blocks) + ] + ) + ) + + down_blocks.append( + MedNeXtDownBlock( + in_channels=init_filters * (2**i), + out_channels=init_filters * (2 ** (i + 1)), + exp_r=enc_exp_r[i], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=spatial_dims_str, + ) + ) + + self.enc_stages = nn.ModuleList(enc_stages) + self.down_blocks = nn.ModuleList(down_blocks) + + self.bottleneck = nn.Sequential( + *[ + MedNeXtBlock( + in_channels=init_filters * (2 ** len(blocks_down)), + out_channels=init_filters * (2 ** len(blocks_down)), + exp_r=bottlenec_exp_r, + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=spatial_dims_str, + grn=grn, + ) + for _ in range(blocks_bottleneck) + ] + ) + + up_blocks = [] + dec_stages = [] + for i, num_blocks in enumerate(blocks_up): + up_blocks.append( + MedNeXtUpBlock( + in_channels=init_filters * (2 ** (len(blocks_up) - i)), + out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), + exp_r=dec_exp_r[i], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=spatial_dims_str, + grn=grn, + ) + ) + + dec_stages.append( + nn.Sequential( + *[ + MedNeXtBlock( + in_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), + out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), + exp_r=dec_exp_r[i], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=spatial_dims_str, + grn=grn, + ) + for _ in range(num_blocks) + ] + ) + ) + + self.up_blocks = nn.ModuleList(up_blocks) + self.dec_stages = nn.ModuleList(dec_stages) + + self.out_0 = OutBlock(in_channels=init_filters, n_classes=out_channels, dim=spatial_dims_str) + + if deep_supervision: + out_blocks = [ + OutBlock(in_channels=init_filters * (2**i), n_classes=out_channels, dim=spatial_dims_str) + for i in range(1, len(blocks_up) + 1) + ] + + out_blocks.reverse() + self.out_blocks = nn.ModuleList(out_blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]: + """ + Forward pass of the MedNeXt model. + + This method performs the forward pass through the model, including: + - Stem convolution + - Encoder stages and downsampling + - Bottleneck blocks + - Decoder stages and upsampling with skip connections + - Output blocks for deep supervision (if enabled) + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor or list[torch.Tensor]: Output tensor(s). 
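+            When deep supervision is enabled and the model is in training
+            mode, a tuple is returned: the full-resolution prediction first,
+            followed by the auxiliary outputs in decreasing resolution order.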
+ """ + # Apply stem convolution + x = self.stem(x) + + # Encoder forward pass + enc_outputs = [] + for enc_stage, down_block in zip(self.enc_stages, self.down_blocks): + x = enc_stage(x) + enc_outputs.append(x) + x = down_block(x) + + # Bottleneck forward pass + x = self.bottleneck(x) + + # Initialize deep supervision outputs if enabled + if self.do_ds: + ds_outputs = [] + + # Decoder forward pass with skip connections + for i, (up_block, dec_stage) in enumerate(zip(self.up_blocks, self.dec_stages)): + if self.do_ds and i < len(self.out_blocks): + ds_outputs.append(self.out_blocks[i](x)) + + x = up_block(x) + x = x + enc_outputs[-(i + 1)] + x = dec_stage(x) + + # Final output block + x = self.out_0(x) + + # Return output(s) + if self.do_ds and self.training: + return (x, *ds_outputs[::-1]) + else: + return x diff --git a/tests/test_mednext.py b/tests/test_mednext.py new file mode 100644 index 0000000000..e5f74118c3 --- /dev/null +++ b/tests/test_mednext.py @@ -0,0 +1,98 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import MedNeXt +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASE_MEDNEXT = [] +for spatial_dims in range(2, 4): + for init_filters in [8, 16]: + for deep_supervision in [False, True]: + for do_res in [False, True]: + for do_res_up_down in [False, True]: + test_case = [ + { + "spatial_dims": spatial_dims, + "init_filters": init_filters, + "deep_supervision": deep_supervision, + "do_res": do_res, + "do_res_up_down": do_res_up_down, + }, + (2, 1, *([16] * spatial_dims)), + (2, 2, *([16] * spatial_dims)), + ] + TEST_CASE_MEDNEXT.append(test_case) + +TEST_CASE_MEDNEXT_2 = [] +for spatial_dims in range(2, 4): + for out_channels in [1, 2]: + for deep_supervision in [False, True]: + test_case = [ + { + "spatial_dims": spatial_dims, + "init_filters": 8, + "out_channels": out_channels, + "deep_supervision": deep_supervision, + }, + (2, 1, *([16] * spatial_dims)), + (2, out_channels, *([16] * spatial_dims)), + ] + TEST_CASE_MEDNEXT_2.append(test_case) + + +class TestMedNeXt(unittest.TestCase): + + @parameterized.expand(TEST_CASE_MEDNEXT) + def test_shape(self, input_param, input_shape, expected_shape): + net = MedNeXt(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + if input_param["deep_supervision"] and net.training: + assert isinstance(result, tuple) + self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + else: + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + @parameterized.expand(TEST_CASE_MEDNEXT_2) + def test_shape2(self, input_param, input_shape, expected_shape): + net = MedNeXt(**input_param).to(device) + + net.train() + result = net(torch.randn(input_shape).to(device)) + if 
input_param["deep_supervision"]: + assert isinstance(result, tuple) + self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + else: + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + net.eval() + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + def test_ill_arg(self): + with self.assertRaises(AssertionError): + MedNeXt(spatial_dims=4) + + +if __name__ == "__main__": + unittest.main() From 93e782f4accd06a72e15e8e96f574f5f27bef658 Mon Sep 17 00:00:00 2001 From: Robin CREMESE Date: Wed, 21 Aug 2024 16:50:57 +0200 Subject: [PATCH 2/6] Code formating for Blake and Flake8 checks to pass + integration of MedNext variants (S, B, M, L) + integration of remarks from @johnzilke (https://github.com/Project-MONAI/MONAI/pull/8004#pullrequestreview-2233276224) for renaming class arguments - removal of self defined LayerNorm - linked residual connection for encoder and decoder Signed-off-by: Robin CREMESE --- monai/networks/blocks/__init__.py | 2 +- monai/networks/blocks/mednext_block.py | 123 ++++++++------- monai/networks/nets/__init__.py | 20 ++- monai/networks/nets/mednext.py | 207 +++++++++++++++++++++---- tests/test_mednext.py | 24 ++- 5 files changed, 273 insertions(+), 103 deletions(-) diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index a535c0ab26..499caf2e0f 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -26,7 +26,7 @@ from .fcn import FCN, GCN, MCFCN, Refine from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7 from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock -from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtUpBlock, OutBlock +from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock from .mlp import MLPBlock from .patchembedding import PatchEmbed, PatchEmbeddingBlock from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock diff --git a/monai/networks/blocks/mednext_block.py b/monai/networks/blocks/mednext_block.py index 42617d2e56..8d0f0433bd 100644 --- a/monai/networks/blocks/mednext_block.py +++ b/monai/networks/blocks/mednext_block.py @@ -17,7 +17,8 @@ import torch import torch.nn as nn -import torch.nn.functional as F + +all = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"] class MedNeXtBlock(nn.Module): @@ -26,26 +27,30 @@ def __init__( self, in_channels: int, out_channels: int, - exp_r: int = 4, + expansion_ratio: int = 4, kernel_size: int = 7, - do_res: int = True, + use_residual_connection: int = True, norm_type: str = "group", - n_groups: int or None = None, dim="3d", grn=False, ): super().__init__() - self.do_res = do_res + self.do_res = use_residual_connection assert dim in ["2d", "3d"] self.dim = dim if self.dim == "2d": conv = nn.Conv2d - else: + normalized_shape = [in_channels, kernel_size, kernel_size] + grn_parameter_shape = (1, 1) + elif self.dim == "3d": conv = nn.Conv3d - + normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size] + grn_parameter_shape = (1, 1, 1) + else: + raise ValueError("dim must be either '2d' or '3d'") # First convolution layer with DepthWise Convolutions self.conv1 = conv( in_channels=in_channels, @@ -53,36 +58,34 @@ def 
__init__( kernel_size=kernel_size, stride=1, padding=kernel_size // 2, - groups=in_channels if n_groups is None else n_groups, + groups=in_channels, ) # Normalization Layer. GroupNorm is used by default. if norm_type == "group": self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) elif norm_type == "layer": - self.norm = LayerNorm(normalized_shape=in_channels, data_format="channels_first") - + self.norm = nn.LayerNorm(normalized_shape=normalized_shape) # Second convolution (Expansion) layer with Conv3D 1x1x1 - self.conv2 = conv(in_channels=in_channels, out_channels=exp_r * in_channels, kernel_size=1, stride=1, padding=0) + self.conv2 = conv( + in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0 + ) # GeLU activations self.act = nn.GELU() # Third convolution (Compression) layer with Conv3D 1x1x1 self.conv3 = conv( - in_channels=exp_r * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 + in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 ) self.grn = grn if self.grn: - if dim == "2d": - self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True) - self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True) - else: - self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True) - self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True) + grn_parameter_shape = (1, expansion_ratio * in_channels) + grn_parameter_shape + self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True) + self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True) - def forward(self, x, dummy_tensor=None): + def forward(self, x): x1 = x x1 = self.conv1(x1) @@ -106,19 +109,34 @@ def forward(self, x, dummy_tensor=None): class MedNeXtDownBlock(MedNeXtBlock): def __init__( - self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False + self, + in_channels: int, + out_channels: int, + expansion_ratio: int = 4, + kernel_size: int = 7, + use_residual_connection: bool = False, + norm_type: str = "group", + dim: str = "3d", + grn: bool = False, ): super().__init__( - in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn + in_channels, + out_channels, + expansion_ratio, + kernel_size, + use_residual_connection=False, + norm_type=norm_type, + dim=dim, + grn=grn, ) if dim == "2d": conv = nn.Conv2d else: conv = nn.Conv3d - self.resample_do_res = do_res - if do_res: + self.resample_do_res = use_residual_connection + if use_residual_connection: self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) self.conv1 = conv( @@ -130,7 +148,7 @@ def __init__( groups=in_channels, ) - def forward(self, x, dummy_tensor=None): + def forward(self, x): x1 = super().forward(x) @@ -144,20 +162,35 @@ def forward(self, x, dummy_tensor=None): class MedNeXtUpBlock(MedNeXtBlock): def __init__( - self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False + self, + in_channels: int, + out_channels: int, + expansion_ratio: int = 4, + kernel_size: int = 7, + use_residual_connection: bool = False, + norm_type: str = "group", + dim: str = "3d", + grn: bool = False, ): super().__init__( - in_channels, out_channels, exp_r, kernel_size, 
do_res=False, norm_type=norm_type, dim=dim, grn=grn + in_channels, + out_channels, + expansion_ratio, + kernel_size, + use_residual_connection=False, + norm_type=norm_type, + dim=dim, + grn=grn, ) - self.resample_do_res = do_res + self.resample_do_res = use_residual_connection self.dim = dim if dim == "2d": conv = nn.ConvTranspose2d else: conv = nn.ConvTranspose3d - if do_res: + if use_residual_connection: self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) self.conv1 = conv( @@ -169,7 +202,7 @@ def __init__( groups=in_channels, ) - def forward(self, x, dummy_tensor=None): + def forward(self, x): x1 = super().forward(x) # Asymmetry but necessary to match shape @@ -190,7 +223,7 @@ def forward(self, x, dummy_tensor=None): return x1 -class OutBlock(nn.Module): +class MedNeXtOutBlock(nn.Module): def __init__(self, in_channels, n_classes, dim): super().__init__() @@ -201,33 +234,5 @@ def __init__(self, in_channels, n_classes, dim): conv = nn.ConvTranspose3d self.conv_out = conv(in_channels, n_classes, kernel_size=1) - def forward(self, x, dummy_tensor=None): + def forward(self, x): return self.conv_out(x) - - -class LayerNorm(nn.Module): - """LayerNorm that supports two data formats: channels_last (default) or channels_first. - The ordering of the dimensions in the inputs. channels_last corresponds to inputs with - shape (batch_size, height, width, channels) while channels_first corresponds to inputs - with shape (batch_size, channels, height, width). - """ - - def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"): - super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) # beta - self.bias = nn.Parameter(torch.zeros(normalized_shape)) # gamma - self.eps = eps - self.data_format = data_format - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError - self.normalized_shape = (normalized_shape,) - - def forward(self, x, dummy_tensor=False): - if self.data_format == "channels_last": - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self.data_format == "channels_first": - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None] - return x diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index f62fe432fa..6dde0b4cc6 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -53,7 +53,25 @@ from .generator import Generator from .highresnet import HighResBlock, HighResNet from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet -from .mednext import MedNeXt +from .mednext import ( + MedNeXt, + MedNext, + MedNextB, + MedNeXtB, + MedNextBase, + MedNextL, + MedNeXtL, + MedNeXtLarge, + MedNextLarge, + MedNextM, + MedNeXtM, + MedNeXtMedium, + MedNextMedium, + MedNextS, + MedNeXtS, + MedNeXtSmall, + MedNextSmall, +) from .milmodel import MILModel from .netadapter import NetAdapter from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py index 293db7d443..e4e68bea20 100644 --- a/monai/networks/nets/mednext.py +++ b/monai/networks/nets/mednext.py @@ -15,10 +15,33 @@ from __future__ import annotations +from collections.abc import Sequence + import torch import torch.nn as nn -from monai.networks.blocks import MedNeXtBlock, MedNeXtDownBlock, MedNeXtUpBlock, 
OutBlock +from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock + +__all__ = [ + "MedNeXt", + "MedNeXtSmall", + "MedNeXtBase", + "MedNeXtMedium", + "MedNeXtLarge", + "MedNext", + "MedNextS", + "MedNeXtS", + "MedNextSmall", + "MedNextB", + "MedNeXtB", + "MedNextBase", + "MedNextM", + "MedNeXtM", + "MedNextMedium", + "MedNextL", + "MedNeXtL", + "MedNextLarge", +] class MedNeXt(nn.Module): @@ -30,13 +53,12 @@ class MedNeXt(nn.Module): init_filters: number of output channels for initial convolution layer. Defaults to 32. in_channels: number of input channels for the network. Defaults to 1. out_channels: number of output channels for the network. Defaults to 2. - enc_exp_r: expansion ratio for encoder blocks. Defaults to 2. - dec_exp_r: expansion ratio for decoder blocks. Defaults to 2. - bottlenec_exp_r: expansion ratio for bottleneck blocks. Defaults to 2. + encoder_expansion_ratio: expansion ratio for encoder blocks. Defaults to 2. + decoder_expansion_ratio: expansion ratio for decoder blocks. Defaults to 2. + bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2. kernel_size: kernel size for convolutions. Defaults to 7. deep_supervision: whether to use deep supervision. Defaults to False. - do_res: whether to use residual connections. Defaults to False. - do_res_up_down: whether to use residual connections in up and down blocks. Defaults to False. + use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False. blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2]. blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2. blocks_up: number of blocks in each decoder stage. Defaults to [2, 2, 2, 2]. 
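(Reviewer note: with the keyword arguments renamed in this patch, a minimal
instantiation of the network looks as follows; the values shown are
illustrative, not changed defaults.)

```python
from monai.networks.nets import MedNeXt

net = MedNeXt(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    encoder_expansion_ratio=2,
    decoder_expansion_ratio=2,
    bottleneck_expansion_ratio=2,
    use_residual_connection=True,
)
```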
@@ -50,16 +72,15 @@ def __init__( init_filters: int = 32, in_channels: int = 1, out_channels: int = 2, - enc_exp_r: int = 2, - dec_exp_r: int = 2, - bottlenec_exp_r: int = 2, + encoder_expansion_ratio: int = 2, + decoder_expansion_ratio: int = 2, + bottleneck_expansion_ratio: int = 2, kernel_size: int = 7, deep_supervision: bool = False, - do_res: bool = False, - do_res_up_down: bool = False, - blocks_down: list = [2, 2, 2, 2], + use_residual_connection: bool = False, + blocks_down: Sequence[int] = (2, 2, 2, 2), blocks_bottleneck: int = 2, - blocks_up: list = [2, 2, 2, 2], + blocks_up: Sequence[int] = (2, 2, 2, 2), norm_type: str = "group", grn: bool = False, ): @@ -80,11 +101,11 @@ def __init__( spatial_dims_str = f"{spatial_dims}d" enc_kernel_size = dec_kernel_size = kernel_size - if isinstance(enc_exp_r, int): - enc_exp_r = [enc_exp_r] * len(blocks_down) + if isinstance(encoder_expansion_ratio, int): + encoder_expansion_ratio = [encoder_expansion_ratio] * len(blocks_down) - if isinstance(dec_exp_r, int): - dec_exp_r = [dec_exp_r] * len(blocks_up) + if isinstance(decoder_expansion_ratio, int): + decoder_expansion_ratio = [decoder_expansion_ratio] * len(blocks_up) conv = nn.Conv2d if spatial_dims_str == "2d" else nn.Conv3d @@ -100,9 +121,9 @@ def __init__( MedNeXtBlock( in_channels=init_filters * (2**i), out_channels=init_filters * (2**i), - exp_r=enc_exp_r[i], + expansion_ratio=encoder_expansion_ratio[i], kernel_size=enc_kernel_size, - do_res=do_res, + use_residual_connection=use_residual_connection, norm_type=norm_type, dim=spatial_dims_str, grn=grn, @@ -116,9 +137,9 @@ def __init__( MedNeXtDownBlock( in_channels=init_filters * (2**i), out_channels=init_filters * (2 ** (i + 1)), - exp_r=enc_exp_r[i], + expansion_ratio=encoder_expansion_ratio[i], kernel_size=enc_kernel_size, - do_res=do_res_up_down, + use_residual_connection=use_residual_connection, norm_type=norm_type, dim=spatial_dims_str, ) @@ -132,9 +153,9 @@ def __init__( MedNeXtBlock( in_channels=init_filters * (2 ** len(blocks_down)), out_channels=init_filters * (2 ** len(blocks_down)), - exp_r=bottlenec_exp_r, + expansion_ratio=bottleneck_expansion_ratio, kernel_size=dec_kernel_size, - do_res=do_res, + use_residual_connection=use_residual_connection, norm_type=norm_type, dim=spatial_dims_str, grn=grn, @@ -150,9 +171,9 @@ def __init__( MedNeXtUpBlock( in_channels=init_filters * (2 ** (len(blocks_up) - i)), out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), - exp_r=dec_exp_r[i], + expansion_ratio=decoder_expansion_ratio[i], kernel_size=dec_kernel_size, - do_res=do_res_up_down, + use_residual_connection=use_residual_connection, norm_type=norm_type, dim=spatial_dims_str, grn=grn, @@ -165,9 +186,9 @@ def __init__( MedNeXtBlock( in_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), - exp_r=dec_exp_r[i], + expansion_ratio=decoder_expansion_ratio[i], kernel_size=dec_kernel_size, - do_res=do_res, + use_residual_connection=use_residual_connection, norm_type=norm_type, dim=spatial_dims_str, grn=grn, @@ -180,11 +201,11 @@ def __init__( self.up_blocks = nn.ModuleList(up_blocks) self.dec_stages = nn.ModuleList(dec_stages) - self.out_0 = OutBlock(in_channels=init_filters, n_classes=out_channels, dim=spatial_dims_str) + self.out_0 = MedNeXtOutBlock(in_channels=init_filters, n_classes=out_channels, dim=spatial_dims_str) if deep_supervision: out_blocks = [ - OutBlock(in_channels=init_filters * (2**i), n_classes=out_channels, dim=spatial_dims_str) + 
MedNeXtOutBlock(in_channels=init_filters * (2**i), n_classes=out_channels, dim=spatial_dims_str) for i in range(1, len(blocks_up) + 1) ] @@ -242,3 +263,131 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]: return (x, *ds_outputs[::-1]) else: return x + + +# Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975 +class MedNeXtSmall(MedNeXt): + """MedNeXt Small (S) configuration""" + + def __init__( + self, + spatial_dims: int = 3, + in_channels: int = 1, + out_channels: int = 2, + kernel_size: int = 3, + deep_supervision: bool = False, + ): + super().__init__( + spatial_dims=spatial_dims, + init_filters=32, + in_channels=in_channels, + out_channels=out_channels, + encoder_expansion_ratio=2, + decoder_expansion_ratio=2, + bottleneck_expansion_ratio=2, + kernel_size=kernel_size, + deep_supervision=deep_supervision, + use_residual_connection=True, + blocks_down=(2, 2, 2, 2), + blocks_bottleneck=2, + blocks_up=(2, 2, 2, 2), + norm_type="group", + grn=False, + ) + + +class MedNeXtBase(MedNeXt): + """MedNeXt Base (B) configuration""" + + def __init__( + self, + spatial_dims: int = 3, + in_channels: int = 1, + out_channels: int = 2, + kernel_size: int = 3, + deep_supervision: bool = False, + ): + super().__init__( + spatial_dims=spatial_dims, + init_filters=32, + in_channels=in_channels, + out_channels=out_channels, + encoder_expansion_ratio=(2, 3, 4, 4), + decoder_expansion_ratio=(4, 4, 3, 2), + bottleneck_expansion_ratio=4, + kernel_size=kernel_size, + deep_supervision=deep_supervision, + use_residual_connection=True, + blocks_down=(2, 2, 2, 2), + blocks_bottleneck=2, + blocks_up=(2, 2, 2, 2), + norm_type="group", + grn=False, + ) + + +class MedNeXtMedium(MedNeXt): + """MedNeXt Medium (M)""" + + def __init__( + self, + spatial_dims: int = 3, + in_channels: int = 1, + out_channels: int = 2, + kernel_size: int = 3, + deep_supervision: bool = False, + ): + super().__init__( + spatial_dims=spatial_dims, + init_filters=32, + in_channels=in_channels, + out_channels=out_channels, + encoder_expansion_ratio=(2, 3, 4, 4), + decoder_expansion_ratio=(4, 4, 3, 2), + bottleneck_expansion_ratio=4, + kernel_size=kernel_size, + deep_supervision=deep_supervision, + use_residual_connection=True, + blocks_down=(3, 4, 4, 4), + blocks_bottleneck=4, + blocks_up=(4, 4, 4, 3), + norm_type="group", + grn=False, + ) + + +class MedNeXtLarge(MedNeXt): + """MedNeXt Large (L)""" + + def __init__( + self, + spatial_dims: int = 3, + in_channels: int = 1, + out_channels: int = 2, + kernel_size: int = 3, + deep_supervision: bool = False, + ): + super().__init__( + spatial_dims=spatial_dims, + init_filters=32, + in_channels=in_channels, + out_channels=out_channels, + encoder_expansion_ratio=(3, 4, 8, 8), + decoder_expansion_ratio=(8, 8, 4, 3), + bottleneck_expansion_ratio=8, + kernel_size=kernel_size, + deep_supervision=deep_supervision, + use_residual_connection=True, + blocks_down=(3, 4, 8, 8), + blocks_bottleneck=8, + blocks_up=(8, 8, 4, 3), + norm_type="group", + grn=False, + ) + + +MedNext = MedNeXt +MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall +MedNextB = MedNeXtB = MedNextBase = MedNeXtBase +MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium +MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge diff --git a/tests/test_mednext.py b/tests/test_mednext.py index e5f74118c3..e39e88f108 100644 --- a/tests/test_mednext.py +++ b/tests/test_mednext.py @@ -27,19 +27,17 @@ for init_filters in [8, 16]: for deep_supervision in [False, True]: for do_res in [False, True]: - for do_res_up_down in 
[False, True]: - test_case = [ - { - "spatial_dims": spatial_dims, - "init_filters": init_filters, - "deep_supervision": deep_supervision, - "do_res": do_res, - "do_res_up_down": do_res_up_down, - }, - (2, 1, *([16] * spatial_dims)), - (2, 2, *([16] * spatial_dims)), - ] - TEST_CASE_MEDNEXT.append(test_case) + test_case = [ + { + "spatial_dims": spatial_dims, + "init_filters": init_filters, + "deep_supervision": deep_supervision, + "use_residual_connection": do_res, + }, + (2, 1, *([16] * spatial_dims)), + (2, 2, *([16] * spatial_dims)), + ] + TEST_CASE_MEDNEXT.append(test_case) TEST_CASE_MEDNEXT_2 = [] for spatial_dims in range(2, 4): From e146aaf297bfe833a7229ee9568976527a8e1256 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 26 Sep 2024 17:18:21 -0400 Subject: [PATCH 3/6] Update mednext implementations Signed-off-by: Suraj Pai --- monai/networks/nets/mednext.py | 147 ++++++++++++--------------------- tests/test_mednext.py | 29 ++++++- 2 files changed, 82 insertions(+), 94 deletions(-) diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py index e4e68bea20..4de4b7b442 100644 --- a/monai/networks/nets/mednext.py +++ b/monai/networks/nets/mednext.py @@ -266,128 +266,89 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]: # Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975 -class MedNeXtSmall(MedNeXt): - """MedNeXt Small (S) configuration""" +def create_mednext( + variant: str, + spatial_dims: int = 3, + in_channels: int = 1, + out_channels: int = 2, + kernel_size: int = 3, + deep_supervision: bool = False, +) -> MedNeXt: + """ + Factory method to create MedNeXt variants. - def __init__( - self, - spatial_dims: int = 3, - in_channels: int = 1, - out_channels: int = 2, - kernel_size: int = 3, - deep_supervision: bool = False, - ): - super().__init__( - spatial_dims=spatial_dims, - init_filters=32, - in_channels=in_channels, - out_channels=out_channels, + Args: + variant (str): The MedNeXt variant to create ('S', 'B', 'M', or 'L'). + spatial_dims (int): Number of spatial dimensions. Defaults to 3. + in_channels (int): Number of input channels. Defaults to 1. + out_channels (int): Number of output channels. Defaults to 2. + kernel_size (int): Kernel size for convolutions. Defaults to 3. + deep_supervision (bool): Whether to use deep supervision. Defaults to False. + + Returns: + MedNeXt: The specified MedNeXt variant. + + Raises: + ValueError: If an invalid variant is specified. 
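+
+    Example (illustrative usage of this factory)::
+
+        net = create_mednext("S", spatial_dims=3, in_channels=1, out_channels=2)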
+ """ + common_args = { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "kernel_size": kernel_size, + "deep_supervision": deep_supervision, + "use_residual_connection": True, + "norm_type": "group", + "grn": False, + "init_filters": 32, + } + + if variant.upper() == "S": + return MedNeXt( encoder_expansion_ratio=2, decoder_expansion_ratio=2, bottleneck_expansion_ratio=2, - kernel_size=kernel_size, - deep_supervision=deep_supervision, - use_residual_connection=True, blocks_down=(2, 2, 2, 2), blocks_bottleneck=2, blocks_up=(2, 2, 2, 2), - norm_type="group", - grn=False, + **common_args, ) - - -class MedNeXtBase(MedNeXt): - """MedNeXt Base (B) configuration""" - - def __init__( - self, - spatial_dims: int = 3, - in_channels: int = 1, - out_channels: int = 2, - kernel_size: int = 3, - deep_supervision: bool = False, - ): - super().__init__( - spatial_dims=spatial_dims, - init_filters=32, - in_channels=in_channels, - out_channels=out_channels, + elif variant.upper() == "B": + return MedNeXt( encoder_expansion_ratio=(2, 3, 4, 4), decoder_expansion_ratio=(4, 4, 3, 2), bottleneck_expansion_ratio=4, - kernel_size=kernel_size, - deep_supervision=deep_supervision, - use_residual_connection=True, blocks_down=(2, 2, 2, 2), blocks_bottleneck=2, blocks_up=(2, 2, 2, 2), - norm_type="group", - grn=False, + **common_args, ) - - -class MedNeXtMedium(MedNeXt): - """MedNeXt Medium (M)""" - - def __init__( - self, - spatial_dims: int = 3, - in_channels: int = 1, - out_channels: int = 2, - kernel_size: int = 3, - deep_supervision: bool = False, - ): - super().__init__( - spatial_dims=spatial_dims, - init_filters=32, - in_channels=in_channels, - out_channels=out_channels, + elif variant.upper() == "M": + return MedNeXt( encoder_expansion_ratio=(2, 3, 4, 4), decoder_expansion_ratio=(4, 4, 3, 2), bottleneck_expansion_ratio=4, - kernel_size=kernel_size, - deep_supervision=deep_supervision, - use_residual_connection=True, blocks_down=(3, 4, 4, 4), blocks_bottleneck=4, blocks_up=(4, 4, 4, 3), - norm_type="group", - grn=False, + **common_args, ) - - -class MedNeXtLarge(MedNeXt): - """MedNeXt Large (L)""" - - def __init__( - self, - spatial_dims: int = 3, - in_channels: int = 1, - out_channels: int = 2, - kernel_size: int = 3, - deep_supervision: bool = False, - ): - super().__init__( - spatial_dims=spatial_dims, - init_filters=32, - in_channels=in_channels, - out_channels=out_channels, + elif variant.upper() == "L": + return MedNeXt( encoder_expansion_ratio=(3, 4, 8, 8), decoder_expansion_ratio=(8, 8, 4, 3), bottleneck_expansion_ratio=8, - kernel_size=kernel_size, - deep_supervision=deep_supervision, - use_residual_connection=True, blocks_down=(3, 4, 8, 8), blocks_bottleneck=8, blocks_up=(8, 8, 4, 3), - norm_type="group", - grn=False, + **common_args, ) + else: + raise ValueError(f"Invalid MedNeXt variant: {variant}") MedNext = MedNeXt -MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall -MedNextB = MedNeXtB = MedNextBase = MedNeXtBase -MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium -MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge +MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall = lambda **kwargs: create_mednext("S", **kwargs) +MedNextB = MedNeXtB = MedNextBase = MedNeXtBase = lambda **kwargs: create_mednext("B", **kwargs) +MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium = lambda **kwargs: create_mednext("M", **kwargs) +MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge = lambda **kwargs: create_mednext("L", **kwargs) diff --git a/tests/test_mednext.py 
b/tests/test_mednext.py index e39e88f108..55ec3cecb6 100644 --- a/tests/test_mednext.py +++ b/tests/test_mednext.py @@ -17,7 +17,7 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets import MedNeXt +from monai.networks.nets import MedNeXt, MedNeXtL, MedNeXtM, MedNeXtS from tests.utils import SkipIfBeforePyTorchVersion, test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -55,6 +55,18 @@ ] TEST_CASE_MEDNEXT_2.append(test_case) +TEST_CASE_MEDNEXT_VARIANTS = [] +for model in [MedNeXtS, MedNeXtM, MedNeXtL]: + for spatial_dims in range(2, 4): + for out_channels in [1, 2]: + test_case = [ + model, + {"spatial_dims": spatial_dims, "in_channels": 1, "out_channels": out_channels}, + (2, 1, *([16] * spatial_dims)), + (2, out_channels, *([16] * spatial_dims)), + ] + TEST_CASE_MEDNEXT_VARIANTS.append(test_case) + class TestMedNeXt(unittest.TestCase): @@ -91,6 +103,21 @@ def test_ill_arg(self): with self.assertRaises(AssertionError): MedNeXt(spatial_dims=4) + @parameterized.expand(TEST_CASE_MEDNEXT_VARIANTS) + def test_mednext_variants(self, model, input_param, input_shape, expected_shape): + net = model(**input_param).to(device) + + net.train() + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + net.eval() + with torch.no_grad(): + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + if __name__ == "__main__": unittest.main() From 247012dff70683afe1ba281e89f5ad36ebb45e20 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Sep 2024 21:24:13 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mednext.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_mednext.py b/tests/test_mednext.py index 55ec3cecb6..4dca898f4a 100644 --- a/tests/test_mednext.py +++ b/tests/test_mednext.py @@ -18,7 +18,6 @@ from monai.networks import eval_mode from monai.networks.nets import MedNeXt, MedNeXtL, MedNeXtM, MedNeXtS -from tests.utils import SkipIfBeforePyTorchVersion, test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" From 17d0579b0c453a820ccf42ce64b3233fad3b2b69 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Fri, 27 Sep 2024 19:37:55 -0400 Subject: [PATCH 5/6] Fix mypy errors Signed-off-by: Suraj Pai --- monai/networks/blocks/mednext_block.py | 41 ++++++++++---------------- monai/networks/nets/mednext.py | 16 +++++----- tests/test_mednext.py | 2 +- 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/monai/networks/blocks/mednext_block.py b/monai/networks/blocks/mednext_block.py index 8d0f0433bd..d5420f92ed 100644 --- a/monai/networks/blocks/mednext_block.py +++ b/monai/networks/blocks/mednext_block.py @@ -21,6 +21,13 @@ all = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"] +def get_conv_layer(spatial_dim: int = 3, transpose: bool = False): + if spatial_dim == 2: + return nn.ConvTranspose2d if transpose else nn.Conv2d + else: # spatial_dim == 3 + return nn.ConvTranspose3d if transpose else nn.Conv3d + + class MedNeXtBlock(nn.Module): def __init__( @@ -39,18 +46,9 @@ def __init__( self.do_res = use_residual_connection - assert dim in ["2d", "3d"] self.dim = dim - if self.dim == "2d": - conv = 
nn.Conv2d - normalized_shape = [in_channels, kernel_size, kernel_size] - grn_parameter_shape = (1, 1) - elif self.dim == "3d": - conv = nn.Conv3d - normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size] - grn_parameter_shape = (1, 1, 1) - else: - raise ValueError("dim must be either '2d' or '3d'") + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3) + grn_parameter_shape = (1,) * (2 if dim == "2d" else 3) # First convolution layer with DepthWise Convolutions self.conv1 = conv( in_channels=in_channels, @@ -63,9 +61,11 @@ def __init__( # Normalization Layer. GroupNorm is used by default. if norm_type == "group": - self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) + self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) # type: ignore elif norm_type == "layer": - self.norm = nn.LayerNorm(normalized_shape=normalized_shape) + self.norm = nn.LayerNorm( + normalized_shape=[in_channels] + [kernel_size] * (2 if dim == "2d" else 3) # type: ignore + ) # Second convolution (Expansion) layer with Conv3D 1x1x1 self.conv2 = conv( in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0 @@ -131,10 +131,7 @@ def __init__( grn=grn, ) - if dim == "2d": - conv = nn.Conv2d - else: - conv = nn.Conv3d + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3) self.resample_do_res = use_residual_connection if use_residual_connection: self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) @@ -186,10 +183,7 @@ def __init__( self.resample_do_res = use_residual_connection self.dim = dim - if dim == "2d": - conv = nn.ConvTranspose2d - else: - conv = nn.ConvTranspose3d + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True) if use_residual_connection: self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) @@ -228,10 +222,7 @@ class MedNeXtOutBlock(nn.Module): def __init__(self, in_channels, n_classes, dim): super().__init__() - if dim == "2d": - conv = nn.ConvTranspose2d - else: - conv = nn.ConvTranspose3d + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True) self.conv_out = conv(in_channels, n_classes, kernel_size=1) def forward(self, x): diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py index 4de4b7b442..bb1d6b534c 100644 --- a/monai/networks/nets/mednext.py +++ b/monai/networks/nets/mednext.py @@ -72,8 +72,8 @@ def __init__( init_filters: int = 32, in_channels: int = 1, out_channels: int = 2, - encoder_expansion_ratio: int = 2, - decoder_expansion_ratio: int = 2, + encoder_expansion_ratio: Sequence[int] | int = 2, + decoder_expansion_ratio: Sequence[int] | int = 2, bottleneck_expansion_ratio: int = 2, kernel_size: int = 7, deep_supervision: bool = False, @@ -212,7 +212,7 @@ def __init__( out_blocks.reverse() self.out_blocks = nn.ModuleList(out_blocks) - def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]: + def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]: """ Forward pass of the MedNeXt model. @@ -227,7 +227,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]: x (torch.Tensor): Input tensor. Returns: - torch.Tensor or list[torch.Tensor]: Output tensor(s). + torch.Tensor or Sequence[torch.Tensor]: Output tensor(s). 
""" # Apply stem convolution x = self.stem(x) @@ -311,7 +311,7 @@ def create_mednext( blocks_down=(2, 2, 2, 2), blocks_bottleneck=2, blocks_up=(2, 2, 2, 2), - **common_args, + **common_args, # type: ignore ) elif variant.upper() == "B": return MedNeXt( @@ -321,7 +321,7 @@ def create_mednext( blocks_down=(2, 2, 2, 2), blocks_bottleneck=2, blocks_up=(2, 2, 2, 2), - **common_args, + **common_args, # type: ignore ) elif variant.upper() == "M": return MedNeXt( @@ -331,7 +331,7 @@ def create_mednext( blocks_down=(3, 4, 4, 4), blocks_bottleneck=4, blocks_up=(4, 4, 4, 3), - **common_args, + **common_args, # type: ignore ) elif variant.upper() == "L": return MedNeXt( @@ -341,7 +341,7 @@ def create_mednext( blocks_down=(3, 4, 8, 8), blocks_bottleneck=8, blocks_up=(8, 8, 4, 3), - **common_args, + **common_args, # type: ignore ) else: raise ValueError(f"Invalid MedNeXt variant: {variant}") diff --git a/tests/test_mednext.py b/tests/test_mednext.py index 4dca898f4a..b4ba4f9939 100644 --- a/tests/test_mednext.py +++ b/tests/test_mednext.py @@ -59,7 +59,7 @@ for spatial_dims in range(2, 4): for out_channels in [1, 2]: test_case = [ - model, + model, # type: ignore {"spatial_dims": spatial_dims, "in_channels": 1, "out_channels": out_channels}, (2, 1, *([16] * spatial_dims)), (2, out_channels, *([16] * spatial_dims)), From efe93b2021ecf3fe0b2e207fc080a2e66cfbcbc0 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 3 Oct 2024 11:31:21 -0400 Subject: [PATCH 6/6] Add docstrings Signed-off-by: Suraj Pai --- monai/networks/blocks/mednext_block.py | 106 ++++++++++++++++++++++--- monai/networks/nets/mednext.py | 14 ++-- 2 files changed, 100 insertions(+), 20 deletions(-) diff --git a/monai/networks/blocks/mednext_block.py b/monai/networks/blocks/mednext_block.py index d5420f92ed..0aa2bb6b58 100644 --- a/monai/networks/blocks/mednext_block.py +++ b/monai/networks/blocks/mednext_block.py @@ -29,6 +29,19 @@ def get_conv_layer(spatial_dim: int = 3, transpose: bool = False): class MedNeXtBlock(nn.Module): + """ + MedNeXtBlock class for the MedNeXt model. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (int): Expansion ratio for the block. Defaults to 4. + kernel_size (int): Kernel size for convolutions. Defaults to 7. + use_residual_connection (int): Whether to use residual connection. Defaults to True. + norm_type (str): Type of normalization to use. Defaults to "group". + dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d". + global_resp_norm (bool): Whether to use global response normalization. Defaults to False. 
+ """ def __init__( self, @@ -39,7 +52,7 @@ def __init__( use_residual_connection: int = True, norm_type: str = "group", dim="3d", - grn=False, + global_resp_norm=False, ): super().__init__() @@ -48,7 +61,7 @@ def __init__( self.dim = dim conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3) - grn_parameter_shape = (1,) * (2 if dim == "2d" else 3) + global_resp_norm_param_shape = (1,) * (2 if dim == "2d" else 3) # First convolution layer with DepthWise Convolutions self.conv1 = conv( in_channels=in_channels, @@ -79,19 +92,27 @@ def __init__( in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 ) - self.grn = grn - if self.grn: - grn_parameter_shape = (1, expansion_ratio * in_channels) + grn_parameter_shape - self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True) - self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True) + self.global_resp_norm = global_resp_norm + if self.global_resp_norm: + global_resp_norm_param_shape = (1, expansion_ratio * in_channels) + global_resp_norm_param_shape + self.global_resp_beta = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True) + self.global_resp_gamma = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True) def forward(self, x): + """ + Forward pass of the MedNeXtBlock. + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ x1 = x x1 = self.conv1(x1) x1 = self.act(self.conv2(self.norm(x1))) - if self.grn: + if self.global_resp_norm: # gamma, beta: learnable affine transform parameters # X: input of shape (N,C,H,W,D) if self.dim == "2d": @@ -99,7 +120,7 @@ def forward(self, x): else: gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True) nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6) - x1 = self.grn_gamma * (x1 * nx) + self.grn_beta + x1 + x1 = self.global_resp_gamma * (x1 * nx) + self.global_resp_beta + x1 x1 = self.conv3(x1) if self.do_res: x1 = x + x1 @@ -107,6 +128,19 @@ def forward(self, x): class MedNeXtDownBlock(MedNeXtBlock): + """ + MedNeXtDownBlock class for downsampling in the MedNeXt model. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (int): Expansion ratio for the block. Defaults to 4. + kernel_size (int): Kernel size for convolutions. Defaults to 7. + use_residual_connection (bool): Whether to use residual connection. Defaults to False. + norm_type (str): Type of normalization to use. Defaults to "group". + dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d". + global_resp_norm (bool): Whether to use global response normalization. Defaults to False. + """ def __init__( self, @@ -117,7 +151,7 @@ def __init__( use_residual_connection: bool = False, norm_type: str = "group", dim: str = "3d", - grn: bool = False, + global_resp_norm: bool = False, ): super().__init__( @@ -128,7 +162,7 @@ def __init__( use_residual_connection=False, norm_type=norm_type, dim=dim, - grn=grn, + global_resp_norm=global_resp_norm, ) conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3) @@ -146,7 +180,15 @@ def __init__( ) def forward(self, x): + """ + Forward pass of the MedNeXtDownBlock. + + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: Output tensor. + """ x1 = super().forward(x) if self.resample_do_res: @@ -157,6 +199,19 @@ def forward(self, x): class MedNeXtUpBlock(MedNeXtBlock): + """ + MedNeXtUpBlock class for upsampling in the MedNeXt model. 
+ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (int): Expansion ratio for the block. Defaults to 4. + kernel_size (int): Kernel size for convolutions. Defaults to 7. + use_residual_connection (bool): Whether to use residual connection. Defaults to False. + norm_type (str): Type of normalization to use. Defaults to "group". + dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d". + global_resp_norm (bool): Whether to use global response normalization. Defaults to False. + """ def __init__( self, @@ -167,7 +222,7 @@ def __init__( use_residual_connection: bool = False, norm_type: str = "group", dim: str = "3d", - grn: bool = False, + global_resp_norm: bool = False, ): super().__init__( in_channels, @@ -177,7 +232,7 @@ def __init__( use_residual_connection=False, norm_type=norm_type, dim=dim, - grn=grn, + global_resp_norm=global_resp_norm, ) self.resample_do_res = use_residual_connection @@ -197,7 +252,15 @@ def __init__( ) def forward(self, x): + """ + Forward pass of the MedNeXtUpBlock. + + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: Output tensor. + """ x1 = super().forward(x) # Asymmetry but necessary to match shape @@ -218,6 +281,14 @@ def forward(self, x): class MedNeXtOutBlock(nn.Module): + """ + MedNeXtOutBlock class for the output block in the MedNeXt model. + + Args: + in_channels (int): Number of input channels. + n_classes (int): Number of output classes. + dim (str): Dimension of the input. Can be "2d" or "3d". + """ def __init__(self, in_channels, n_classes, dim): super().__init__() @@ -226,4 +297,13 @@ def __init__(self, in_channels, n_classes, dim): self.conv_out = conv(in_channels, n_classes, kernel_size=1) def forward(self, x): + """ + Forward pass of the MedNeXtOutBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ return self.conv_out(x) diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py index bb1d6b534c..427572ba60 100644 --- a/monai/networks/nets/mednext.py +++ b/monai/networks/nets/mednext.py @@ -63,7 +63,7 @@ class MedNeXt(nn.Module): blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2. blocks_up: number of blocks in each decoder stage. Defaults to [2, 2, 2, 2]. norm_type: type of normalization layer. Defaults to 'group'. - grn: whether to use Global Response Normalization (GRN). Defaults to False. + global_resp_norm: whether to use Global Response Normalization. Defaults to False. Refer: https://arxiv.org/abs/2301.00808 """ def __init__( @@ -82,7 +82,7 @@ def __init__( blocks_bottleneck: int = 2, blocks_up: Sequence[int] = (2, 2, 2, 2), norm_type: str = "group", - grn: bool = False, + global_resp_norm: bool = False, ): """ Initialize the MedNeXt model. 
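(Reviewer note: for context on this rename, the global response normalization
applied inside MedNeXtBlock follows the ConvNeXt-V2 formulation referenced in
the docstring; a standalone sketch of the same computation, assuming a 3D
input of shape (N, C, H, W, D):)

```python
import torch

def global_response_norm(x, gamma, beta, eps=1e-6):
    # L2 norm over the spatial dimensions
    gx = torch.norm(x, p=2, dim=(-3, -2, -1), keepdim=True)
    # divisive normalization across channels
    nx = gx / (gx.mean(dim=1, keepdim=True) + eps)
    # learnable affine transform plus identity residual
    return gamma * (x * nx) + beta + x
```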
@@ -126,7 +126,7 @@ def __init__( use_residual_connection=use_residual_connection, norm_type=norm_type, dim=spatial_dims_str, - grn=grn, + global_resp_norm=global_resp_norm, ) for _ in range(num_blocks) ] @@ -158,7 +158,7 @@ def __init__( use_residual_connection=use_residual_connection, norm_type=norm_type, dim=spatial_dims_str, - grn=grn, + global_resp_norm=global_resp_norm, ) for _ in range(blocks_bottleneck) ] @@ -176,7 +176,7 @@ def __init__( use_residual_connection=use_residual_connection, norm_type=norm_type, dim=spatial_dims_str, - grn=grn, + global_resp_norm=global_resp_norm, ) ) @@ -191,7 +191,7 @@ def __init__( use_residual_connection=use_residual_connection, norm_type=norm_type, dim=spatial_dims_str, - grn=grn, + global_resp_norm=global_resp_norm, ) for _ in range(num_blocks) ] @@ -299,7 +299,7 @@ def create_mednext( "deep_supervision": deep_supervision, "use_residual_connection": True, "norm_type": "group", - "grn": False, + "global_resp_norm": False, "init_filters": 32, }
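(Reviewer note: a minimal end-to-end sketch of the API added by this series;
shapes mirror the unit tests, and create_mednext is imported from its defining
module since the package-level exports shown here cover only the class aliases.)

```python
import torch

from monai.networks.nets import MedNeXt
from monai.networks.nets.mednext import create_mednext

# explicit construction with deep supervision
net = MedNeXt(spatial_dims=3, in_channels=1, out_channels=2, deep_supervision=True)
net.train()
out = net(torch.randn(2, 1, 16, 16, 16))
# in training mode with deep supervision, a tuple is returned:
# the full-resolution logits first, then the auxiliary outputs
assert isinstance(out, tuple) and out[0].shape == (2, 2, 16, 16, 16)

# variant factory: 'S', 'B', 'M' or 'L'
net_s = create_mednext("S", spatial_dims=3, in_channels=1, out_channels=2)
net_s.eval()
with torch.no_grad():
    y = net_s(torch.randn(2, 1, 16, 16, 16))
assert y.shape == (2, 2, 16, 16, 16)
```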