diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 402d0eed..798c0464 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -674,6 +674,45 @@ def normalize_img(img, normalize=True, norm3D=False, invert=False, lowhigh=None, return img_norm +def resize_safe(img, Ly, Lx, interpolation=cv2.INTER_LINEAR): + """OpenCV resize function does not support uint32. + + This function converts the image to float32 before resizing and then converts it back to uint32. Not safe! + References issue: https://github.com/MouseLand/cellpose/issues/937 + + Implications: + * Runtime: Runtime increases by 5x-50x due to type casting. However, with resizing being very efficient, this is not + a big issue. A 10,000x10,000 image takes 0.47s instead of 0.016s to cast and resize on 32 cores on GPU. + * Memory: However, memory usage increases. Not tested by how much. + + Args: + img (ndarray): Image of size [Ly x Lx]. + Ly (int): Desired height of the resized image. + Lx (int): Desired width of the resized image. + interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR. + + Returns: + ndarray: Resized image of size [Ly x Lx]. + + """ + + # cast image + cast = img.dtype == np.uint32 + if cast: + # + img = img.astype(np.float32) + + # resize + img = cv2.resize(img, (Lx, Ly), interpolation=interpolation) + + # cast back + if cast: + transforms_logger.warning("resizing image from uint32 to float32 and back to uint32") + img = img.round().astype(np.uint32) + + return img + + def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEAR, no_channels=False): """Resize image for computing flows / unresize for computing dynamics. @@ -721,9 +760,9 @@ def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEA else: imgs = np.zeros((img0.shape[0], Ly, Lx, img0.shape[-1]), np.float32) for i, img in enumerate(img0): - imgs[i] = cv2.resize(img, (Lx, Ly), interpolation=interpolation) + imgs[i] = resize_safe(img, Ly, Lx, interpolation=interpolation) else: - imgs = cv2.resize(img0, (Lx, Ly), interpolation=interpolation) + imgs = resize_safe(img0, Ly, Lx, interpolation=interpolation) return imgs diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 91f40d7e..10c20c76 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -21,3 +21,22 @@ def test_normalize_img(data_dir): img_norm = normalize_img(img, norm3D=False, sharpen_radius=8) assert img_norm.shape == img.shape + +def test_resize(data_dir): + img = io.imread(str(data_dir.joinpath('2D').joinpath('rgb_2D_tif.tif'))) + + Lx = 100 + Ly = 200 + + img8 = resize_image(img.astype("uint8"), Lx=Lx, Ly=Ly) + assert img8.shape == (Ly, Lx, 3) + assert img8.dtype == np.uint8 + + img16 = resize_image(img.astype("uint16"), Lx=Lx, Ly=Ly) + assert img16.shape == (Ly, Lx, 3) + assert img16.dtype == np.uint16 + + img32 = resize_image(img.astype("uint32"), Lx=Lx, Ly=Ly) + assert img32.shape == (Ly, Lx, 3) + assert img32.dtype == np.uint32 +