From 356d2d2f41b473f588899d705bbc682308cee52c Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Mon, 25 Jul 2022 17:06:58 +0100 Subject: [PATCH] 4757 update patch merging (#4758) * update patch merging Signed-off-by: Wenqi Li * fixes unit tests Signed-off-by: Wenqi Li --- monai/networks/nets/__init__.py | 2 +- monai/networks/nets/swin_unetr.py | 71 +++++++++++++++++++++++++++---- tests/test_swin_unetr.py | 6 ++- tests/utils.py | 4 ++ 4 files changed, 72 insertions(+), 11 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 6bf1868122..a4e8312b30 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -81,7 +81,7 @@ seresnext50, seresnext101, ) -from .swin_unetr import SwinUNETR +from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex from .unet import UNet, Unet diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index a6b0c5c6a7..f42e139027 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple, Type, Union +from typing import Optional, Sequence, Tuple, Type, Union import numpy as np import torch @@ -21,10 +21,23 @@ from monai.networks.blocks import MLPBlock as Mlp from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock from monai.networks.layers import DropPath, trunc_normal_ -from monai.utils import ensure_tuple_rep, optional_import +from monai.utils import ensure_tuple_rep, look_up_option, optional_import rearrange, _ = optional_import("einops", name="rearrange") +__all__ = [ + "SwinUNETR", + "window_partition", + "window_reverse", + "WindowAttention", + "SwinTransformerBlock", + "PatchMerging", + "PatchMergingV2", + "MERGING_MODE", + "BasicLayer", + "SwinTransformer", +] + class SwinUNETR(nn.Module): """ @@ -48,6 +61,7 @@ def __init__( normalize: bool = True, use_checkpoint: bool = False, spatial_dims: int = 3, + downsample="merging", ) -> None: """ Args: @@ -64,6 +78,9 @@ def __init__( normalize: normalize output intermediate features in each stage. use_checkpoint: use gradient checkpointing for reduced memory usage. spatial_dims: number of spatial dims. + downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a + user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`. + The default is currently `"merging"` (the original version defined in v0.9.0). Examples:: @@ -121,6 +138,7 @@ def __init__( norm_layer=nn.LayerNorm, use_checkpoint=use_checkpoint, spatial_dims=spatial_dims, + downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample, ) self.encoder1 = UnetrBasicBlock( @@ -657,7 +675,7 @@ def forward(self, x, mask_matrix): return x -class PatchMerging(nn.Module): +class PatchMergingV2(nn.Module): """ Patch merging layer based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer using Shifted Windows @@ -695,8 +713,8 @@ def forward(self, x): x2 = x[:, 0::2, 1::2, 0::2, :] x3 = x[:, 0::2, 0::2, 1::2, :] x4 = x[:, 1::2, 0::2, 1::2, :] - x5 = x[:, 0::2, 1::2, 0::2, :] - x6 = x[:, 0::2, 0::2, 1::2, :] + x5 = x[:, 1::2, 1::2, 0::2, :] + x6 = x[:, 0::2, 1::2, 1::2, :] x7 = x[:, 1::2, 1::2, 1::2, :] x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) @@ -716,6 +734,36 @@ def forward(self, x): return x +class PatchMerging(PatchMergingV2): + """The `PatchMerging` module previously defined in v0.9.0.""" + + def forward(self, x): + x_shape = x.size() + if len(x_shape) == 4: + return super().forward(x) + if len(x_shape) != 5: + raise ValueError(f"expecting 5D x, got {x.shape}.") + b, d, h, w, c = x_shape + pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2)) + x0 = x[:, 0::2, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, 0::2, :] + x2 = x[:, 0::2, 1::2, 0::2, :] + x3 = x[:, 0::2, 0::2, 1::2, :] + x4 = x[:, 1::2, 0::2, 1::2, :] + x5 = x[:, 0::2, 1::2, 0::2, :] + x6 = x[:, 0::2, 0::2, 1::2, :] + x7 = x[:, 1::2, 1::2, 1::2, :] + x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) + x = self.norm(x) + x = self.reduction(x) + return x + + +MERGING_MODE = {"merging": PatchMerging, "mergingv2": PatchMergingV2} + + def compute_mask(dims, window_size, shift_size, device): """Computing region masks based on: "Liu et al., Swin Transformer: Hierarchical Vision Transformer using Shifted Windows @@ -776,7 +824,7 @@ def __init__( drop: float = 0.0, attn_drop: float = 0.0, norm_layer: Type[LayerNorm] = nn.LayerNorm, - downsample: isinstance = None, # type: ignore + downsample: Optional[nn.Module] = None, use_checkpoint: bool = False, ) -> None: """ @@ -791,7 +839,7 @@ def __init__( drop: dropout rate. attn_drop: attention dropout rate. norm_layer: normalization layer. - downsample: downsample layer at the end of the layer. + downsample: an optional downsampling layer at the end of the layer. use_checkpoint: use gradient checkpointing for reduced memory usage. """ @@ -820,7 +868,7 @@ def __init__( ] ) self.downsample = downsample - if self.downsample is not None: + if callable(self.downsample): self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size)) def forward(self, x): @@ -881,6 +929,7 @@ def __init__( patch_norm: bool = False, use_checkpoint: bool = False, spatial_dims: int = 3, + downsample="merging", ) -> None: """ Args: @@ -899,6 +948,9 @@ def __init__( patch_norm: add normalization after patch embedding. use_checkpoint: use gradient checkpointing for reduced memory usage. spatial_dims: spatial dimension. + downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a + user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`. + The default is currently `"merging"` (the original version defined in v0.9.0). """ super().__init__() @@ -920,6 +972,7 @@ def __init__( self.layers2 = nn.ModuleList() self.layers3 = nn.ModuleList() self.layers4 = nn.ModuleList() + down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2**i_layer), @@ -932,7 +985,7 @@ def __init__( drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, - downsample=PatchMerging, + downsample=down_sample_mod, use_checkpoint=use_checkpoint, ) if i_layer == 0: diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index d102c36d54..6188d6225a 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -16,12 +16,14 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets.swin_unetr import PatchMerging, SwinUNETR +from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from monai.utils import optional_import einops, has_einops = optional_import("einops") TEST_CASE_SWIN_UNETR = [] +case_idx = 0 +test_merging_mode = ["mergingv2", "merging", PatchMerging, PatchMergingV2] for attn_drop_rate in [0.4]: for in_channels in [1]: for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]: @@ -39,10 +41,12 @@ "depths": depth, "norm_name": norm_name, "attn_drop_rate": attn_drop_rate, + "downsample": test_merging_mode[case_idx % 4], }, (2, in_channels, *img_size), (2, out_channels, *img_size), ] + case_idx += 1 TEST_CASE_SWIN_UNETR.append(test_case) diff --git a/tests/utils.py b/tests/utils.py index 2a061b3ee4..22b1d96b1e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,6 +17,7 @@ import operator import os import queue +import ssl import sys import tempfile import time @@ -123,6 +124,9 @@ def skip_if_downloading_fails(): yield except (ContentTooShortError, HTTPError, ConnectionError) as e: raise unittest.SkipTest(f"error while downloading: {e}") from e + except ssl.SSLError as ssl_e: + if "decryption failed" in str(ssl_e): + raise unittest.SkipTest(f"SSL error while downloading: {ssl_e}") from ssl_e except RuntimeError as rt_e: if "unexpected EOF" in str(rt_e): raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e # incomplete download