From dfbae0bcc7152929c17af3dec7d1872b9733c155 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Mon, 27 May 2024 12:40:16 +0200 Subject: [PATCH] Force tiles to be a multiple of 2**depth when using UNet. Translational invariance is broken in UNets, due to the pooling layers. To avoid artefacts during prediction, `tile_size` is now forced to be a multiple of 2**depth. This does not impact the "Lightning API", but can be an issue for small Z depth during prediction. In such a case, padding can be added, or prediction can be performed while bypassing the CAREamist. --- src/careamics/careamist.py | 8 +- src/careamics/config/configuration_factory.py | 47 +++++++--- src/careamics/config/inference_model.py | 18 ++-- .../lightning_prediction_datamodule.py | 6 ++ tests/config/test_configuration_factory.py | 89 ++++++++++++++++++- 5 files changed, 143 insertions(+), 25 deletions(-) diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py index 7a9c8dae..0f4f8e91 100644 --- a/src/careamics/careamist.py +++ b/src/careamics/careamist.py @@ -543,6 +543,12 @@ def predict( Test-time augmentation (TTA) can be switched off using the `tta_transforms` parameter. + Note that if you are using a UNet model and tiling, the tile size must be + divisible in every dimension by 2**d, where d is the depth of the model. This + avoids artefacts arising from the broken shift invariance induced by the + pooling layers of the UNet. If your image has less dimensions, as it may + happen in the Z dimension, consider padding your image. 
+ Parameters ---------- source : Union[CAREamicsClay, Path, str, np.ndarray] @@ -597,7 +603,7 @@ def predict( ) # create predict config, reuse training config if parameters missing prediction_config = create_inference_configuration( - training_configuration=self.cfg, + configuration=self.cfg, tile_size=tile_size, tile_overlap=tile_overlap, data_type=data_type, diff --git a/src/careamics/config/configuration_factory.py b/src/careamics/config/configuration_factory.py index f9454fba..84ef5d3a 100644 --- a/src/careamics/config/configuration_factory.py +++ b/src/careamics/config/configuration_factory.py @@ -583,9 +583,8 @@ def create_n2v_configuration( return configuration -# TODO add tests def create_inference_configuration( - training_configuration: Configuration, + configuration: Configuration, tile_size: Optional[Tuple[int, ...]] = None, tile_overlap: Optional[Tuple[int, ...]] = None, data_type: Optional[Literal["array", "tiff", "custom"]] = None, @@ -602,8 +601,8 @@ def create_inference_configuration( Parameters ---------- - training_configuration : Configuration - Configuration used for training. + configuration : Configuration + Global configuration. tile_size : Tuple[int, ...], optional Size of the tiles. tile_overlap : Tuple[int, ...], optional @@ -622,14 +621,12 @@ def create_inference_configuration( Returns ------- InferenceConfiguration - Configuration for inference with N2V. + Configuration used to configure CAREamicsPredictData. 
""" - if ( - training_configuration.data_config.mean is None - or training_configuration.data_config.std is None - ): - raise ValueError("Mean and std must be provided in the training configuration.") + if configuration.data_config.mean is None or configuration.data_config.std is None: + raise ValueError("Mean and std must be provided in the configuration.") + # minimum transform if transforms is None: transforms = [ { @@ -637,13 +634,35 @@ def create_inference_configuration( }, ] + # tile size for UNets + if tile_size is not None: + model = configuration.algorithm_config.model + + if model.architecture == SupportedArchitecture.UNET.value: + # tile size must be equal to k*2^n, where n is the number of pooling layers + # (equal to the depth) and k is an integer + depth = model.depth + tile_increment = 2**depth + + for i, t in enumerate(tile_size): + if t % tile_increment != 0: + raise ValueError( + f"Tile size must be divisible by {tile_increment} along all " + f"axes (got {t} for axis {i}). If your image size is smaller " + f"along one axis (e.g. Z), consider padding the image." 
+ ) + + # tile overlaps must be specified + if tile_overlap is None: + raise ValueError("Tile overlap must be specified.") + return InferenceConfig( - data_type=data_type or training_configuration.data_config.data_type, + data_type=data_type or configuration.data_config.data_type, tile_size=tile_size, tile_overlap=tile_overlap, - axes=axes or training_configuration.data_config.axes, - mean=training_configuration.data_config.mean, - std=training_configuration.data_config.std, + axes=axes or configuration.data_config.axes, + mean=configuration.data_config.mean, + std=configuration.data_config.std, transforms=transforms, tta_transforms=tta_transforms, batch_size=batch_size, diff --git a/src/careamics/config/inference_model.py b/src/careamics/config/inference_model.py index faa71f01..6ee37b7b 100644 --- a/src/careamics/config/inference_model.py +++ b/src/careamics/config/inference_model.py @@ -52,22 +52,22 @@ class InferenceConfig(BaseModel): @field_validator("tile_overlap") @classmethod def all_elements_non_zero_even( - cls, patch_list: Optional[Union[List[int]]] + cls, tile_overlap: Optional[Union[List[int]]] ) -> Optional[Union[List[int]]]: """ - Validate patch size. + Validate tile overlap. - Patch size must be non-zero, positive and even. + Overlaps must be non-zero, positive and even. Parameters ---------- - patch_list : Optional[Union[List[int]]] + tile_overlap : Optional[Union[List[int]]] Patch size. Returns ------- Optional[Union[List[int]]] - Validated patch size. + Validated tile overlap. Raises ------ @@ -76,8 +76,8 @@ def all_elements_non_zero_even( ValueError If the patch size is not even. """ - if patch_list is not None: - for dim in patch_list: + if tile_overlap is not None: + for dim in tile_overlap: if dim < 1: raise ValueError( f"Patch size must be non-zero positive (got {dim})." 
@@ -86,7 +86,7 @@ def all_elements_non_zero_even( if dim % 2 != 0: raise ValueError(f"Patch size must be even (got {dim}).") - return patch_list + return tile_overlap @field_validator("tile_size") @classmethod @@ -173,7 +173,7 @@ def validate_transforms( ValueError If transforms contain N2V pixel manipulate transforms. """ - if not isinstance(transforms, Compose) and transforms is not None: + if transforms is not None: for transform in transforms: if transform.name == SupportedTransform.N2V_MANIPULATE.value: raise ValueError( diff --git a/src/careamics/lightning_prediction_datamodule.py b/src/careamics/lightning_prediction_datamodule.py index 670f4574..af2cbcaf 100644 --- a/src/careamics/lightning_prediction_datamodule.py +++ b/src/careamics/lightning_prediction_datamodule.py @@ -277,6 +277,12 @@ class PredictDataWrapper(CAREamicsPredictData): dataloaders, except for `batch_size`, which is set by the `batch_size` parameter. + Note that if you are using a UNet model and tiling, the tile size must be + divisible in every dimension by 2**d, where d is the depth of the model. This + avoids artefacts arising from the broken shift invariance induced by the + pooling layers of the UNet. If your image has less dimensions, as it may + happen in the Z dimension, consider padding your image. 
+ Parameters ---------- pred_data : Union[str, Path, np.ndarray] diff --git a/tests/config/test_configuration_factory.py b/tests/config/test_configuration_factory.py index dc1e74cd..6b515585 100644 --- a/tests/config/test_configuration_factory.py +++ b/tests/config/test_configuration_factory.py @@ -2,6 +2,7 @@ from careamics.config import ( create_care_configuration, + create_inference_configuration, create_n2n_configuration, create_n2v_configuration, ) @@ -34,7 +35,7 @@ def test_n2n_configuration(): assert not config.algorithm_config.model.is_3D() -def test_cn2n_3d_configuration(): +def test_n2n_3d_configuration(): """Test that a 3D N2N configurationc an be created.""" config = create_care_configuration( experiment_name="test", @@ -519,3 +520,89 @@ def test_structn2v(): == SupportedStructAxis.HORIZONTAL.value ) assert config.data_config.transforms[-1].struct_mask_span == 7 + + +def test_inference_config_no_stats(): + """Test that an inference configuration fails if no statistics are present.""" + config = create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + ) + + with pytest.raises(ValueError): + create_inference_configuration( + configuration=config, + ) + + +def test_inference_config(): + """Test that an inference configuration can be created.""" + config = create_n2v_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + ) + config.data_config.set_mean_and_std(0.5, 0.2) + + inf_config = create_inference_configuration( + configuration=config, + ) + assert len(inf_config.transforms) == 1 + + +def test_inference_tile_size(): + """Test that an inference configuration can be created for a UNet model.""" + config = create_care_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + ) + config.data_config.set_mean_and_std(0.5, 
0.2) + + # check UNet depth, tile increment must then be a factor of 4 + assert config.algorithm_config.model.depth == 2 + + # error if not a factor of 4 + with pytest.raises(ValueError): + create_inference_configuration( + configuration=config, + tile_size=[6, 6], + tile_overlap=[2, 2], + ) + + # no error if a factor of 4 + create_inference_configuration( + configuration=config, + tile_size=[8, 8], + tile_overlap=[2, 2], + ) + + +def test_inference_tile_no_overlap(): + """Test that an error is raised if the tile overlap is not specified, but the tile + size is.""" + config = create_care_configuration( + experiment_name="test", + data_type="tiff", + axes="YX", + patch_size=[64, 64], + batch_size=8, + num_epochs=100, + ) + config.data_config.set_mean_and_std(0.5, 0.2) + + with pytest.raises(ValueError): + create_inference_configuration( + configuration=config, + tile_size=[8, 8], + )