Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix intensity normalizations #981

Merged
merged 6 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 86 additions & 61 deletions cellpose/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""

import numpy as np
import logging
import warnings

import cv2
import numpy as np
import torch
from torch.fft import fft2, ifft2, fftshift
from scipy.ndimage import gaussian_filter1d

import logging
from torch.fft import fft2, fftshift, ifft2

transforms_logger = logging.getLogger(__name__)

from . import dynamics, utils


def _taper_mask(ly=224, lx=224, sig=7.5):
"""
Expand Down Expand Up @@ -492,7 +490,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha
transforms_logger.warning(f"z_axis not specified, assuming it is dim {z_axis}")
transforms_logger.warning(f"if this is actually the channel_axis, use 'model.eval(channel_axis={z_axis}, ...)'")
z_axis = 0

if z_axis is not None:
if x.ndim == 3:
x = x[..., np.newaxis]
Expand All @@ -512,7 +510,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha

if channel_axis is None:
x = move_min_dim(x)

if x.ndim > 3:
transforms_logger.info(
"multi-stack tiff read in as having %d planes %d channels" %
Expand All @@ -533,7 +531,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha
% (nchan, nchan))
x = x[..., :nchan]

#if not do_3D and x.ndim > 3:
# if not do_3D and x.ndim > 3:
# transforms_logger.critical("ERROR: cannot process 4D images in 2D mode")
# raise ValueError("ERROR: cannot process 4D images in 2D mode")

Expand Down Expand Up @@ -598,21 +596,24 @@ def reshape(data, channels=[0, 0], chan_first=False):


def normalize_img(img, normalize=True, norm3D=False, invert=False, lowhigh=None,
percentile=None, sharpen_radius=0, smooth_radius=0,
percentile=(1., 99.), sharpen_radius=0, smooth_radius=0,
tile_norm_blocksize=0, tile_norm_smooth3D=1, axis=-1):
"""Normalize each channel of the image.
"""Normalize each channel of the image with optional inversion, smoothing, and sharpening.

Args:
img (ndarray): The input image. It should have at least 3 dimensions.
If it is 4-dimensional, it assumes the first non-channel axis is the Z dimension.
normalize (bool, optional): Whether to perform normalization. Defaults to True.
norm3D (bool, optional): Whether to normalize in 3D. Defaults to False.
norm3D (bool, optional): Whether to normalize in 3D. If True, the entire 3D stack will
be normalized per channel. If False, normalization is applied per Z-slice. Defaults to False.
invert (bool, optional): Whether to invert the image. Useful if cells are dark instead of bright.
Defaults to False.
lowhigh (tuple, optional): The lower and upper bounds for normalization. If provided, it should be a tuple
of two values. Defaults to None.
lowhigh (tuple or ndarray, optional): The lower and upper bounds for normalization.
Can be a tuple of two values (applied to all channels) or an array of shape (nchan, 2)
for per-channel normalization. Incompatible with smoothing and sharpening.
Defaults to None.
percentile (tuple, optional): The lower and upper percentiles for normalization. If provided, it should be
a tuple of two values. Each value should be between 0 and 100. Defaults to None.
a tuple of two values. Each value should be between 0 and 100. Defaults to (1.0, 99.0).
sharpen_radius (int, optional): The radius for sharpening the image. Defaults to 0.
smooth_radius (int, optional): The radius for smoothing the image. Defaults to 0.
tile_norm_blocksize (int, optional): The block size for tile-based normalization. Defaults to 0.
Expand All @@ -633,96 +634,120 @@ def normalize_img(img, normalize=True, norm3D=False, invert=False, lowhigh=None,
transforms_logger.critical(error_message)
raise ValueError(error_message)

if lowhigh is not None:
assert len(lowhigh) == 2
assert lowhigh[1] > lowhigh[0]
elif percentile is not None:
assert len(percentile) == 2
assert percentile[0] >= 0 and percentile[1] > 0
assert percentile[0] < 100 and percentile[1] <= 100
assert percentile[1] > percentile[0]
else:
percentile = [1., 99.]

img_norm = img.astype(np.float32)
# move channel axis last
img_norm = np.moveaxis(img_norm, axis, -1)
img_norm = np.moveaxis(img_norm, axis, -1) # Move channel axis to last

nchan = img_norm.shape[-1]

# Validate and handle lowhigh bounds
if lowhigh is not None:
lowhigh = np.array(lowhigh)
if lowhigh.shape == (2,):
lowhigh = np.tile(lowhigh, (nchan, 1)) # Expand to per-channel bounds
elif lowhigh.shape != (nchan, 2):
error_message = "`lowhigh` must have shape (2,) or (nchan, 2)"
transforms_logger.critical(error_message)
raise ValueError(error_message)

# Validate percentile
if percentile is None:
percentile = (1.0, 99.0)
elif not (0 <= percentile[0] < percentile[1] <= 100):
error_message = "Invalid percentile range, should be between 0 and 100"
transforms_logger.critical(error_message)
raise ValueError(error_message)

# Apply normalization based on lowhigh or percentile
if lowhigh is not None:
for c in range(nchan):
img_norm[...,
c] = (img_norm[..., c] - lowhigh[0]) / (lowhigh[1] - lowhigh[0])
lower = lowhigh[c, 0]
upper = lowhigh[c, 1]
img_norm[..., c] = (img_norm[..., c] - lower) / (upper - lower)

else:
# Apply sharpening and smoothing if specified
if sharpen_radius > 0 or smooth_radius > 0:
img_norm = smooth_sharpen_img(img_norm, sharpen_radius=sharpen_radius,
smooth_radius=smooth_radius)
img_norm = smooth_sharpen_img(
img_norm, sharpen_radius=sharpen_radius, smooth_radius=smooth_radius
)

# Apply tile-based normalization or standard normalization
if tile_norm_blocksize > 0:
img_norm = normalize99_tile(img_norm, blocksize=tile_norm_blocksize,
lower=percentile[0], upper=percentile[1],
smooth3D=tile_norm_smooth3D, norm3D=norm3D)
img_norm = normalize99_tile(
img_norm,
blocksize=tile_norm_blocksize,
lower=percentile[0],
upper=percentile[1],
smooth3D=tile_norm_smooth3D,
norm3D=norm3D,
)
elif normalize:
if img_norm.ndim == 3 or norm3D:
if img_norm.ndim == 3 or norm3D: # i.e. if YXC, or ZYXC with norm3D=True
for c in range(nchan):
img_norm[..., c] = normalize99(img_norm[...,
c], lower=percentile[0],
upper=percentile[1], copy=False)
else:
img_norm[..., c] = normalize99(
img_norm[..., c],
lower=percentile[0],
upper=percentile[1],
copy=False,
)
else: # i.e. if ZYXC with norm3D=False then per Z-slice
for z in range(img_norm.shape[0]):
for c in range(nchan):
img_norm[z, :, :,
c] = normalize99(img_norm[z, :, :,
c], lower=percentile[0],
upper=percentile[1], copy=False)
if (tile_norm_blocksize > 0 or normalize) and invert:
img_norm[..., c] = -1 * img_norm[..., c] + 1
elif invert:
error_message = "cannot invert image without normalizing"
img_norm[z, ..., c] = normalize99(
img_norm[z, ..., c],
lower=percentile[0],
upper=percentile[1],
copy=False,
)

if invert:
if lowhigh is not None or tile_norm_blocksize > 0 or normalize:
img_norm = 1 - img_norm
else:
error_message = "Cannot invert image without normalization"
transforms_logger.critical(error_message)
raise ValueError(error_message)

# move channel axis back to original position
# Move channel axis back to the original position
img_norm = np.moveaxis(img_norm, -1, axis)

return img_norm

def resize_safe(img, Ly, Lx, interpolation=cv2.INTER_LINEAR):
    """Resize an image with OpenCV, working around its missing uint32 support.

    uint32 input is converted to float32, resized, rounded and converted back
    to uint32; every other dtype is handed to ``cv2.resize`` unchanged. The
    round trip is not safe: float32 cannot represent every uint32 value
    exactly, so large values may come back altered.
    See https://github.com/MouseLand/cellpose/issues/937.

    Implications (as reported by the original author):
        * Runtime: 5x-50x slower because of the type casts, which is still
          small in absolute terms (0.47s instead of 0.016s for a
          10,000x10,000 image).
        * Memory: usage increases; the amount has not been measured.

    Args:
        img (ndarray): Image of size [Ly x Lx].
        Ly (int): Desired height of the resized image.
        Lx (int): Desired width of the resized image.
        interpolation (int, optional): OpenCV interpolation method.
            Defaults to cv2.INTER_LINEAR.

    Returns:
        ndarray: Resized image of size [Ly x Lx].

    """
    needs_roundtrip = img.dtype == np.uint32

    # Only uint32 needs the detour through float32.
    source = img.astype(np.float32) if needs_roundtrip else img

    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(source, (Lx, Ly), interpolation=interpolation)
    if not needs_roundtrip:
        return resized

    transforms_logger.warning("resizing image from uint32 to float32 and back to uint32")
    return resized.round().astype(np.uint32)


Expand Down
106 changes: 84 additions & 22 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,104 @@
from cellpose.transforms import *
from cellpose import io
import numpy as np
import pytest

from cellpose.io import imread
from cellpose.transforms import normalize_img, random_rotate_and_resize, resize_image


@pytest.fixture
def img_3d(data_dir):
    """Load the 3D test stack as float32 with its second axis moved to the end."""
    stack_path = data_dir.joinpath('3D').joinpath('rgb_3D.tif')
    stack = imread(str(stack_path))
    # Reorder axes (0, 1, 2, 3) -> (0, 2, 3, 1); the tests index channels on the last axis.
    channels_last = stack.transpose(0, 2, 3, 1)
    return channels_last.astype('float32')


@pytest.fixture
def img_2d(data_dir):
    """Load the 2D test image exactly as ``imread`` returns it."""
    image_path = data_dir.joinpath('2D').joinpath('rgb_2D_tif.tif')
    return imread(str(image_path))


def test_random_rotate_and_resize__default():
    """Smoke test: default arguments accept a list of random 64x64 images."""
    n_images = 2
    images = [np.random.rand(64, 64) for _ in range(n_images)]

    # Only checks that the call completes without raising.
    random_rotate_and_resize(images)


def test_normalize_img(img_3d):
    """Normalization keeps the input shape for each supported option set."""
    option_sets = (
        {"norm3D": True},
        {"norm3D": True, "tile_norm_blocksize": 25},
        {"norm3D": False, "sharpen_radius": 8},
    )
    for options in option_sets:
        normalized = normalize_img(img_3d, **options)
        assert normalized.shape == img_3d.shape


def test_normalize_img_with_lowhigh_and_invert(img_3d):
    """Check explicit ``lowhigh`` bounds, per-channel bounds, and inversion."""
    # Bounds strictly inside the data range: output must overshoot [0, 1].
    overshoot = normalize_img(img_3d, lowhigh=(img_3d.min() + 1, img_3d.max() - 1))
    assert overshoot.min() < 0 and overshoot.max() > 1

    # Bounds equal to the data range: output must stay inside [0, 1].
    in_range = normalize_img(img_3d, lowhigh=(img_3d.min(), img_3d.max()))
    assert 0 <= in_range.min() < in_range.max() <= 1

    # One (low, high) pair per channel, for the first two channels.
    per_channel_bounds = tuple(
        (img_3d[..., chan].min(), img_3d[..., chan].max()) for chan in (0, 1)
    )

    img_norm_channelwise = normalize_img(img_3d, lowhigh=per_channel_bounds)
    assert img_norm_channelwise.min() >= 0 and img_norm_channelwise.max() <= 1

    img_norm_channelwise_inverted = normalize_img(
        img_3d, lowhigh=per_channel_bounds, invert=True
    )
    # Inversion should map x to 1 - x.
    np.testing.assert_allclose(
        img_norm_channelwise, 1 - img_norm_channelwise_inverted, rtol=1e-3
    )

def test_resize(data_dir):
img = io.imread(str(data_dir.joinpath('2D').joinpath('rgb_2D_tif.tif')))

def test_normalize_img_exceptions(img_3d):
    """Every invalid call below must raise ``ValueError``."""
    single_plane = img_3d[0, ..., 0]

    # (image, keyword arguments) pairs, each expected to be rejected.
    invalid_calls = [
        (single_plane, {}),
        (img_3d, {"lowhigh": (0, 1, 2)}),
        (img_3d, {"lowhigh": ((0, 1), (0, 1, 2))}),
        (img_3d, {"lowhigh": ((0, 1),) * 4}),
        (img_3d, {"percentile": (1, 101)}),
        (
            img_3d,
            {
                "lowhigh": None,
                "tile_norm_blocksize": 0,
                "normalize": False,
                "invert": True,
            },
        ),
    ]

    for image, kwargs in invalid_calls:
        with pytest.raises(ValueError):
            normalize_img(image, **kwargs)


def test_resize(img_2d):
    """Resizing keeps the dtype and yields a (Ly, Lx, 3) image for each uint type."""
    Lx = 100
    Ly = 200

    dtype_cases = (
        ("uint8", np.uint8),
        ("uint16", np.uint16),
        ("uint32", np.uint32),
    )
    for dtype_name, expected_dtype in dtype_cases:
        resized = resize_image(img_2d.astype(dtype_name), Lx=Lx, Ly=Ly)
        assert resized.shape == (Ly, Lx, 3)
        assert resized.dtype == expected_dtype