diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py
index 832135ad06..32b817d584 100644
--- a/monai/networks/nets/swin_unetr.py
+++ b/monai/networks/nets/swin_unetr.py
@@ -13,7 +13,6 @@
 
 import itertools
 from collections.abc import Sequence
-from typing import Final
 
 import numpy as np
 import torch
@@ -51,8 +50,6 @@ class SwinUNETR(nn.Module):
     <https://arxiv.org/abs/2201.01266>"
     """
 
-    patch_size: Final[int] = 2
-
     @deprecated_arg(
         name="img_size",
         since="1.3",
@@ -65,18 +62,24 @@ def __init__(
         img_size: Sequence[int] | int,
         in_channels: int,
         out_channels: int,
+        patch_size: int = 2,
         depths: Sequence[int] = (2, 2, 2, 2),
         num_heads: Sequence[int] = (3, 6, 12, 24),
+        window_size: Sequence[int] | int = 7,
+        qkv_bias: bool = True,
+        mlp_ratio: float = 4.0,
         feature_size: int = 24,
         norm_name: tuple | str = "instance",
         drop_rate: float = 0.0,
         attn_drop_rate: float = 0.0,
         dropout_path_rate: float = 0.0,
         normalize: bool = True,
+        norm_layer: type[LayerNorm] = nn.LayerNorm,
+        patch_norm: bool = True,
         use_checkpoint: bool = False,
         spatial_dims: int = 3,
-        downsample="merging",
-        use_v2=False,
+        downsample: str | nn.Module = "merging",
+        use_v2: bool = False,
     ) -> None:
         """
         Args:
@@ -86,14 +89,20 @@ def __init__(
                 It will be removed in an upcoming version.
             in_channels: dimension of input channels.
             out_channels: dimension of output channels.
+            patch_size: size of the patch token.
             feature_size: dimension of network feature size.
             depths: number of layers in each stage.
             num_heads: number of attention heads.
+            window_size: local window size.
+            qkv_bias: add a learnable bias to query, key, value.
+            mlp_ratio: ratio of mlp hidden dim to embedding dim.
             norm_name: feature normalization type and arguments.
             drop_rate: dropout rate.
             attn_drop_rate: attention dropout rate.
             dropout_path_rate: drop path rate.
             normalize: normalize output intermediate features in each stage.
+            norm_layer: normalization layer.
+            patch_norm: whether to apply normalization to the patch embedding.
             use_checkpoint: use gradient checkpointing for reduced memory usage.
             spatial_dims: number of spatial dims.
             downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
@@ -116,13 +125,15 @@ def __init__(
 
         super().__init__()
 
-        img_size = ensure_tuple_rep(img_size, spatial_dims)
-        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
-        window_size = ensure_tuple_rep(7, spatial_dims)
-
         if spatial_dims not in (2, 3):
             raise ValueError("spatial dimension should be 2 or 3.")
 
+        self.patch_size = patch_size
+
+        img_size = ensure_tuple_rep(img_size, spatial_dims)
+        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
+        window_size = ensure_tuple_rep(window_size, spatial_dims)
+
         self._check_input_size(img_size)
 
         if not (0 <= drop_rate <= 1):
@@ -146,12 +157,13 @@ def __init__(
             patch_size=patch_sizes,
             depths=depths,
             num_heads=num_heads,
-            mlp_ratio=4.0,
-            qkv_bias=True,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
             drop_rate=drop_rate,
             attn_drop_rate=attn_drop_rate,
             drop_path_rate=dropout_path_rate,
-            norm_layer=nn.LayerNorm,
+            norm_layer=norm_layer,
+            patch_norm=patch_norm,
             use_checkpoint=use_checkpoint,
             spatial_dims=spatial_dims,
             downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
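
For context, a minimal usage sketch of the patched constructor. It passes the previously hard-coded values (`patch_size=2`, `window_size=7`, `mlp_ratio=4.0`, `qkv_bias=True`) explicitly, now that the patch exposes them as arguments; the input shape and channel counts are illustrative only:

```python
import torch

from monai.networks.nets import SwinUNETR

# Values below match the old hard-coded defaults, but are now supplied
# through the constructor arguments introduced by this patch.
model = SwinUNETR(
    img_size=(96, 96, 96),  # deprecated since 1.3, but still required by this signature
    in_channels=1,
    out_channels=2,
    patch_size=2,      # was a Final class attribute
    window_size=7,     # was hard-coded as ensure_tuple_rep(7, spatial_dims)
    qkv_bias=True,     # was hard-coded in the SwinTransformer call
    mlp_ratio=4.0,     # likewise
    patch_norm=True,   # newly forwarded to SwinTransformer
    feature_size=24,
    spatial_dims=3,
)

x = torch.randn(1, 1, 96, 96, 96)  # (batch, channel, D, H, W)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 96, 96, 96])
```

Note that `window_size` accepts either an int or a per-dimension sequence, mirroring the `ensure_tuple_rep(window_size, spatial_dims)` expansion in the patched `__init__`.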