Skip to content

Commit

Permalink
Enforce specific tile size in UNet to reduce artifacts (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps authored May 28, 2024
2 parents d4ae469 + d4f1d77 commit eea3ecd
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 24 deletions.
8 changes: 7 additions & 1 deletion src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,12 @@ def predict(
Test-time augmentation (TTA) can be switched off using the `tta_transforms`
parameter.
Note that if you are using a UNet model and tiling, the tile size must be
divisible in every dimension by 2**d, where d is the depth of the model. This
avoids artefacts arising from the broken shift invariance induced by the
pooling layers of the UNet. If your image has less dimensions, as it may
happen in the Z dimension, consider padding your image.
Parameters
----------
source : Union[CAREamicsClay, Path, str, np.ndarray]
Expand Down Expand Up @@ -597,7 +603,7 @@ def predict(
)
# create predict config, reuse training config if parameters missing
prediction_config = create_inference_configuration(
training_configuration=self.cfg,
configuration=self.cfg,
tile_size=tile_size,
tile_overlap=tile_overlap,
data_type=data_type,
Expand Down
47 changes: 33 additions & 14 deletions src/careamics/config/configuration_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,9 +581,8 @@ def create_n2v_configuration(
return configuration


# TODO add tests
def create_inference_configuration(
training_configuration: Configuration,
configuration: Configuration,
tile_size: Optional[Tuple[int, ...]] = None,
tile_overlap: Optional[Tuple[int, ...]] = None,
data_type: Optional[Literal["array", "tiff", "custom"]] = None,
Expand All @@ -600,8 +599,8 @@ def create_inference_configuration(
Parameters
----------
training_configuration : Configuration
Configuration used for training.
configuration : Configuration
Global configuration.
tile_size : Tuple[int, ...], optional
Size of the tiles.
tile_overlap : Tuple[int, ...], optional
Expand All @@ -620,28 +619,48 @@ def create_inference_configuration(
Returns
-------
InferenceConfiguration
Configuration for inference with N2V.
Configuration used to configure CAREamicsPredictData.
"""
if (
training_configuration.data_config.mean is None
or training_configuration.data_config.std is None
):
raise ValueError("Mean and std must be provided in the training configuration.")
if configuration.data_config.mean is None or configuration.data_config.std is None:
raise ValueError("Mean and std must be provided in the configuration.")

# minimum transform
if transforms is None:
transforms = [
{
"name": SupportedTransform.NORMALIZE.value,
},
]

# tile size for UNets
if tile_size is not None:
model = configuration.algorithm_config.model

if model.architecture == SupportedArchitecture.UNET.value:
# tile size must be equal to k*2^n, where n is the number of pooling layers
# (equal to the depth) and k is an integer
depth = model.depth
tile_increment = 2**depth

for i, t in enumerate(tile_size):
if t % tile_increment != 0:
raise ValueError(
f"Tile size must be divisible by {tile_increment} along all "
f"axes (got {t} for axis {i}). If your image size is smaller "
f"along one axis (e.g. Z), consider padding the image."
)

# tile overlaps must be specified
if tile_overlap is None:
raise ValueError("Tile overlap must be specified.")

return InferenceConfig(
data_type=data_type or training_configuration.data_config.data_type,
data_type=data_type or configuration.data_config.data_type,
tile_size=tile_size,
tile_overlap=tile_overlap,
axes=axes or training_configuration.data_config.axes,
mean=training_configuration.data_config.mean,
std=training_configuration.data_config.std,
axes=axes or configuration.data_config.axes,
mean=configuration.data_config.mean,
std=configuration.data_config.std,
transforms=transforms,
tta_transforms=tta_transforms,
batch_size=batch_size,
Expand Down
16 changes: 8 additions & 8 deletions src/careamics/config/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,22 @@ class InferenceConfig(BaseModel):
@field_validator("tile_overlap")
@classmethod
def all_elements_non_zero_even(
cls, patch_list: Optional[Union[List[int]]]
cls, tile_overlap: Optional[Union[List[int]]]
) -> Optional[Union[List[int]]]:
"""
Validate patch size.
Validate tile overlap.
Patch size must be non-zero, positive and even.
Overlaps must be non-zero, positive and even.
Parameters
----------
patch_list : Optional[Union[List[int]]]
tile_overlap : Optional[Union[List[int]]]
Patch size.
Returns
-------
Optional[Union[List[int]]]
Validated patch size.
Validated tile overlap.
Raises
------
Expand All @@ -75,8 +75,8 @@ def all_elements_non_zero_even(
ValueError
If the patch size is not even.
"""
if patch_list is not None:
for dim in patch_list:
if tile_overlap is not None:
for dim in tile_overlap:
if dim < 1:
raise ValueError(
f"Patch size must be non-zero positive (got {dim})."
Expand All @@ -85,7 +85,7 @@ def all_elements_non_zero_even(
if dim % 2 != 0:
raise ValueError(f"Patch size must be even (got {dim}).")

return patch_list
return tile_overlap

@field_validator("tile_size")
@classmethod
Expand Down
6 changes: 6 additions & 0 deletions src/careamics/lightning_prediction_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,12 @@ class PredictDataWrapper(CAREamicsPredictData):
dataloaders, except for `batch_size`, which is set by the `batch_size`
parameter.
Note that if you are using a UNet model and tiling, the tile size must be
divisible in every dimension by 2**d, where d is the depth of the model. This
avoids artefacts arising from the broken shift invariance induced by the
pooling layers of the UNet. If your image has less dimensions, as it may
happen in the Z dimension, consider padding your image.
Parameters
----------
pred_data : Union[str, Path, np.ndarray]
Expand Down
89 changes: 88 additions & 1 deletion tests/config/test_configuration_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from careamics.config import (
create_care_configuration,
create_inference_configuration,
create_n2n_configuration,
create_n2v_configuration,
)
Expand Down Expand Up @@ -32,7 +33,7 @@ def test_n2n_configuration():
assert not config.algorithm_config.model.is_3D()


def test_cn2n_3d_configuration():
def test_n2n_3d_configuration():
"""Test that a 3D N2N configurationc an be created."""
config = create_care_configuration(
experiment_name="test",
Expand Down Expand Up @@ -507,3 +508,89 @@ def test_structn2v():
== SupportedStructAxis.HORIZONTAL.value
)
assert config.data_config.transforms[-1].struct_mask_span == 7


def test_inference_config_no_stats():
"""Test that an inference configuration fails if no statistics are present."""
config = create_n2v_configuration(
experiment_name="test",
data_type="tiff",
axes="YX",
patch_size=[64, 64],
batch_size=8,
num_epochs=100,
)

with pytest.raises(ValueError):
create_inference_configuration(
configuration=config,
)


def test_inference_config():
"""Test that an inference configuration can be created."""
config = create_n2v_configuration(
experiment_name="test",
data_type="tiff",
axes="YX",
patch_size=[64, 64],
batch_size=8,
num_epochs=100,
)
config.data_config.set_mean_and_std(0.5, 0.2)

inf_config = create_inference_configuration(
configuration=config,
)
assert len(inf_config.transforms) == 1


def test_inference_tile_size():
"""Test that an inference configuration can be created for a UNet model."""
config = create_care_configuration(
experiment_name="test",
data_type="tiff",
axes="YX",
patch_size=[64, 64],
batch_size=8,
num_epochs=100,
)
config.data_config.set_mean_and_std(0.5, 0.2)

# check UNet depth, tile increment must then be a factor of 4
assert config.algorithm_config.model.depth == 2

# error if not a factor of 4
with pytest.raises(ValueError):
create_inference_configuration(
configuration=config,
tile_size=[6, 6],
tile_overlap=[2, 2],
)

# no error if a factor of 4
create_inference_configuration(
configuration=config,
tile_size=[8, 8],
tile_overlap=[2, 2],
)


def test_inference_tile_no_overlap():
"""Test that an error is raised if the tile overlap is not specified, but the tile
size is."""
config = create_care_configuration(
experiment_name="test",
data_type="tiff",
axes="YX",
patch_size=[64, 64],
batch_size=8,
num_epochs=100,
)
config.data_config.set_mean_and_std(0.5, 0.2)

with pytest.raises(ValueError):
create_inference_configuration(
configuration=config,
tile_size=[8, 8],
)

0 comments on commit eea3ecd

Please sign in to comment.