Skip to content

Commit

Permalink
Fix: Inconsistent prediction outputs (#170)
Browse files Browse the repository at this point in the history
### Description

Prediction outputs were inconsistent. Firstly, if there was only 1
sample the output would not be a list. Secondly, and less predictably,
if the prediction was tiled, then singleton channel dimensions would be
dropped. This PR makes sure that the output of `CAREamist.predict` is
always a list and that the outputs always have the same dimensions.

**What**:
- The `stitched_prediction` function will now output arrays with
dimensions `SC(Z)YX` where S will always be singleton.
- Tests have been updated to always expect lists from
`CAREamist.predict`. Additionally, instead of squeezing prediction
outputs, the input arrays are reshaped when comparing dimensions; this
should make the tests a little more robust.

**Why**:
- Having `stitched_prediction` output `SC(Z)YX` arrays simplifies the
`prediction_utils.convert_outputs` function, in which, if a prediction
has 3 dimensions it can be ambiguous whether the first dimension is Z or
C. If the first dimension is Z then we need to add both the S & C
singleton dims, and if the first dimension is C then we only need to add
the S singleton dimension.

**How**: 
- Solved not very elegantly by expanding the `predicted_image`
dimensions, based on the first tile's dimensions, in
`stitch_prediction_single`.
- This can be solved more elegantly by allowing
`TileInformation.array_shape` to have singleton dimensions.

### Changes Made

**Modified**: 
  - `careamics.prediction_utils.convert_outputs`
  - `careamics.prediction_utils.stitch_prediction_single`
  - `CAREamsit.export_to_bmz`
  - tests  

### Related Issues

- Fixes #156 
- `stitch_prediction_single` now accepts inputs with dimensions
`SC(Z)YX`. There was an incorrect assumption that C was the first axis
before this PR.

### Breaking changes

Code working with prediction outputs.

### Additional Notes and Examples

Related to allowing `TileInformation.array_shape` to have singleton
dimensions, in my opinion, after data is initially input and reshaped,
we should always assume all internal code is working with data in the
form SC(Z)YX. This will reduce some ambiguity and confusion.

The prediction output dimensions could be changed to C(Z)YZ as I
initially suggested in #156, since S is always 1. This would mean
replacing `np.concatenate` with `np.stack` in a lot of places that
combine prediction samples.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
  • Loading branch information
melisande-c authored Jul 3, 2024
1 parent da187e7 commit f223aa7
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 94 deletions.
23 changes: 6 additions & 17 deletions src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,27 +651,16 @@ def export_to_bmz(
data_description : str, optional
Description of the data, by default None.
"""
input_patch = reshape_array(input_array, self.cfg.data_config.axes)
# TODO: add in docs that it is expected that input_array dimensions match
# those in data_config

# axes need to be reformated for the export because reshaping was done in the
# datamodule
if "Z" in self.cfg.data_config.axes:
axes = "SCZYX"
else:
axes = "SCYX"

# predict output, remove extra dimensions for the purpose of the prediction
output_patch = self.predict(
input_patch,
input_array,
data_type=SupportedData.ARRAY.value,
axes=axes,
tta_transforms=False,
)

if isinstance(output_patch, list):
output = np.concatenate(output_patch, axis=0)
else:
output = output_patch
output = np.concatenate(output_patch, axis=0)
input_array = reshape_array(input_array, self.cfg.data_config.axes)

export_to_bmz(
model=self.model,
Expand All @@ -680,7 +669,7 @@ def export_to_bmz(
name=name,
general_description=general_description,
authors=authors,
input_array=input_patch,
input_array=input_array,
output_array=output,
channel_names=channel_names,
data_description=data_description,
Expand Down
60 changes: 15 additions & 45 deletions src/careamics/prediction_utils/prediction_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,10 @@ def convert_outputs(
# this layout is to stop mypy complaining
if tiled:
predictions_comb = combine_batches(predictions, tiled)
# remove sample dimension (always 1) `stitch_predict` func expects no S dim
tiles = [pred[0] for pred in predictions_comb[0]]
tile_infos = predictions_comb[1]
predictions_output = stitch_prediction(tiles, tile_infos)
predictions_output = stitch_prediction(*predictions_comb)
else:
predictions_output = combine_batches(predictions, tiled)

# TODO: add this in? Returns output with same axes as input
# Won't work with tiling rn because stitch_prediction func removes S axis
# predictions = reshape(predictions, axes)
# At least make sure stitched prediction and non-tiled prediction have matching axes

# TODO: might want to remove this
if len(predictions_output) == 1:
return predictions_output[0]
return predictions_output


Expand Down Expand Up @@ -94,7 +83,7 @@ def combine_batches(
if tiled:
return _combine_tiled_batches(predictions)
else:
return _combine_untiled_batches(predictions)
return _combine_array_batches(predictions)


def _combine_tiled_batches(
Expand All @@ -105,8 +94,11 @@ def _combine_tiled_batches(
Parameters
----------
predictions : list
Predictions that are output from `Trainer.predict`.
predictions : list of (numpy.ndarray, list of TileInformation)
Predictions that are output from `Trainer.predict`. For tiled batches, this is
a list of tuples. The first element of the tuples is the prediction output of
tiles with dimension (B, C, (Z), Y, X), where B is batch size. The second
element of the tuples is a list of TileInformation objects of length B.
Returns
-------
Expand All @@ -117,49 +109,27 @@ def _combine_tiled_batches(
tile_infos = [
tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
]
prediction_tiles: List[NDArray] = _combine_untiled_batches(
prediction_tiles: List[NDArray] = _combine_array_batches(
[preds for preds, _ in predictions]
)
return prediction_tiles, tile_infos


def _combine_untiled_batches(predictions: List[NDArray]) -> List[NDArray]:
def _combine_array_batches(predictions: List[NDArray]) -> List[NDArray]:
"""
Combine batches from un-tiled output.
Combine batches of arrays.
Parameters
----------
predictions : list
Predictions that are output from `Trainer.predict`.
predictions : list
Prediction arrays that are output from `Trainer.predict`. A list of arrays that
have dimensions (B, C, (Z), Y, X), where B is batch size.
Returns
-------
list of nunpy.ndarray
Combined batches.
list of numpy.ndarray
A list of arrays with dimensions (1, C, (Z), Y, X).
"""
prediction_concat: NDArray = np.concatenate(predictions, axis=0)
prediction_split = np.split(prediction_concat, prediction_concat.shape[0], axis=0)
return prediction_split


def reshape(predictions: List[NDArray], axes: str) -> List[NDArray]:
"""
Reshape predictions to have dimensions of input.
Parameters
----------
predictions : list
Predictions that are output from `Trainer.predict`.
axes : str
Axes SC(Z)YX.
Returns
-------
List[NDArray]
Reshaped predicitions.
"""
if "C" not in axes:
predictions = [pred[:, 0] for pred in predictions]
if "S" not in axes:
predictions = [pred[0] for pred in predictions]
return predictions
31 changes: 17 additions & 14 deletions src/careamics/prediction_utils/stitch_prediction.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Prediction utility functions."""

from typing import List
import builtins
from typing import List, Union

import numpy as np
from numpy.typing import NDArray

from careamics.config.tile_information import TileInformation

Expand Down Expand Up @@ -52,9 +54,9 @@ def stitch_prediction(


def stitch_prediction_single(
tiles: List[np.ndarray],
tiles: List[NDArray],
tile_infos: List[TileInformation],
) -> np.ndarray:
) -> NDArray:
"""
Stitch tiles back together to form a full image.
Expand All @@ -72,29 +74,30 @@ def stitch_prediction_single(
Returns
-------
numpy.ndarray
Full image.
Full image, with dimensions SC(Z)YX.
"""
# retrieve whole array size
input_shape = tile_infos[0].array_shape
predicted_image = np.zeros(input_shape, dtype=np.float32)

# reshape
# TODO: can be more elegantly solved if TileInformation allows singleton dims
singleton_dims = tuple(np.where(np.array(tiles[0].shape) == 1)[0])
predicted_image = np.expand_dims(predicted_image, singleton_dims)

for tile, tile_info in zip(tiles, tile_infos):
n_channels = tile.shape[0]

# Compute coordinates for cropping predicted tile
slices = (slice(0, n_channels),) + tuple(
[slice(c[0], c[1]) for c in tile_info.overlap_crop_coords]
crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
...,
*[slice(c[0], c[1]) for c in tile_info.overlap_crop_coords],
)

# Crop predited tile according to overlap coordinates
cropped_tile = tile[slices]
cropped_tile = tile[crop_slices]

# Insert cropped tile into predicted image using stitch coordinates
predicted_image[
(
...,
*[slice(c[0], c[1]) for c in tile_info.stitch_coords],
)
] = cropped_tile.astype(np.float32)
image_slices = (..., *[slice(c[0], c[1]) for c in tile_info.stitch_coords])
predicted_image[image_slices] = cropped_tile.astype(np.float32)

return predicted_image
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ def pre_trained_bmz(tmp_path, pre_trained) -> Path:
careamist = CAREamist(source=pre_trained, work_dir=tmp_path)

# predict (no tiling and no tta)
predicted = careamist.predict(train_array, tta_transforms=False)
predicted_output = careamist.predict(train_array, tta_transforms=False)
predicted = np.concatenate(predicted_output, axis=0)

# export to BioImage Model Zoo
path = tmp_path / "model.zip"
Expand Down
6 changes: 4 additions & 2 deletions tests/model_io/test_bmz_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def test_state_dict_io(tmp_path, ordered_array, pre_trained):
careamist = CAREamist(source=pre_trained, work_dir=tmp_path)

# predict (no tiling and no tta)
predicted = careamist.predict(train_array, tta_transforms=False)
predicted_output = careamist.predict(train_array, tta_transforms=False)
predicted = np.concatenate(predicted_output, axis=0)

# save model
_export_state_dict(careamist.model, path)
Expand All @@ -39,7 +40,8 @@ def test_bmz_io(tmp_path, ordered_array, pre_trained):
careamist = CAREamist(source=pre_trained, work_dir=tmp_path)

# predict (no tiling and no tta)
predicted = careamist.predict(train_array, tta_transforms=False)
predicted_output = careamist.predict(train_array, tta_transforms=False)
predicted = np.concatenate(predicted_output, axis=0)

# export to BioImage Model Zoo
path = tmp_path / "model.zip"
Expand Down
12 changes: 5 additions & 7 deletions tests/prediction_utils/test_prediction_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,20 @@ def test_convert_outputs_tiled(ordered_array, batch_size, n_samples):
prediction_batches.append((tiles, tile_infos))

predictions = convert_outputs(prediction_batches, tiled=True)
# TODO: fix convert_outputs so output shape is the same as input shape
# (Or always SC(Z)YX)
assert np.array_equal(np.array(predictions), arr.squeeze())
assert np.array_equal(np.stack(predictions, axis=0).squeeze(), arr.squeeze())


@pytest.mark.parametrize("batch_size, n_samples", [(1, 1), (1, 2), (2, 2)])
def test_convert_outputs_not_tiled(ordered_array, batch_size, n_samples):
"""Test conversion of output for when prediction is not tiled"""
# --- simulate outputs from trainer.predict
# TODO: could test for case with different size batch at the end
prediction_batches = [
ordered_array((batch_size, 1, 16, 16)) for _ in range(n_samples // batch_size)
]
predictions = convert_outputs(prediction_batches, tiled=False)
if not isinstance(predictions, list): # single predictions not returned as list
predictions = [predictions]
assert np.array_equal(
np.concatenate(predictions, axis=0), np.concatenate(prediction_batches, axis=0)
# stack predictions because there is no S axis
# squeeze to remove singleton S or C axes
np.stack(predictions, axis=0).squeeze(),
np.concatenate(prediction_batches, axis=0).squeeze(),
)
4 changes: 2 additions & 2 deletions tests/prediction_utils/test_stitch_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_stitch_tiles_single(ordered_array, input_shape, tile_size, overlaps):
result = stitch_prediction_single(tiles, tile_infos)

# check equality with the correct sample
assert np.array_equal(result, arr[sample_id].squeeze())
assert np.array_equal(result, arr[sample_id])
sample_id += 1

# clear the lists
Expand Down Expand Up @@ -85,6 +85,6 @@ def test_stitch_tiles_multi(ordered_array, input_shape, tile_size, overlaps):

stitched = stitch_prediction(tiles, tile_infos)
for sample_id, result in enumerate(stitched):
assert np.array_equal(result, arr[sample_id].squeeze())
assert np.array_equal(result, arr[sample_id])

assert len(stitched) == n_samples
27 changes: 21 additions & 6 deletions tests/test_careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from careamics import CAREamist, Configuration, save_configuration
from careamics.callbacks import HyperParametersCallback, ProgressBarCallback
from careamics.config.support import SupportedAlgorithm, SupportedData
from careamics.dataset.dataset_utils import reshape_array


def random_array(shape: Tuple[int, ...], seed: int = 42):
Expand Down Expand Up @@ -531,9 +532,11 @@ def test_predict_on_array_tiled(
predicted = careamist.predict(
train_array, batch_size=batch_size, tile_size=(16, 16), tile_overlap=(4, 4)
)
predicted_squeeze = [p.squeeze() for p in predicted]

assert np.array(predicted_squeeze).shape == train_array.squeeze().shape
assert (
np.concatenate(predicted).shape
== reshape_array(train_array, config.data_config.axes).shape
)

# export to BMZ
careamist.export_to_bmz(
Expand Down Expand Up @@ -572,7 +575,10 @@ def test_predict_arrays_no_tiling(
# predict CAREamist
predicted = careamist.predict(train_array, batch_size=batch_size)

assert np.concatenate(predicted).squeeze().shape == train_array.squeeze().shape
assert (
np.concatenate(predicted).shape
== reshape_array(train_array, config.data_config.axes).shape
)

# export to BMZ
careamist.export_to_bmz(
Expand Down Expand Up @@ -658,7 +664,10 @@ def test_predict_tiled_channel(
train_array, batch_size=batch_size, tile_size=(16, 16), tile_overlap=(4, 4)
)

assert predicted.squeeze().shape == train_array.shape
assert (
np.concatenate(predicted).shape
== reshape_array(train_array, config.data_config.axes).shape
)


@pytest.mark.parametrize("tiled", [True, False])
Expand Down Expand Up @@ -738,7 +747,10 @@ def test_predict_pretrained_checkpoint(tmp_path: Path, pre_trained: Path):
predicted = careamist.predict(source_array)

# check that it predicted
assert predicted.squeeze().shape == source_array.shape
assert (
np.concatenate(predicted).shape
== reshape_array(source_array, careamist.cfg.data_config.axes).shape
)


def test_predict_pretrained_bmz(tmp_path: Path, pre_trained_bmz: Path):
Expand All @@ -753,7 +765,10 @@ def test_predict_pretrained_bmz(tmp_path: Path, pre_trained_bmz: Path):
predicted = careamist.predict(source_array)

# check that it predicted
assert predicted.squeeze().shape == source_array.shape
assert (
np.concatenate(predicted).shape
== reshape_array(source_array, careamist.cfg.data_config.axes).shape
)


def test_export_bmz_pretrained_prediction(tmp_path: Path, pre_trained: Path):
Expand Down

0 comments on commit f223aa7

Please sign in to comment.