diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index b4ae562911..5bc38f69ea 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -12,8 +12,11 @@ from __future__ import annotations import glob +import gzip +import io import os import re +import tempfile import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence @@ -51,6 +54,9 @@ pydicom, has_pydicom = optional_import("pydicom") nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) +cp, has_cp = optional_import("cupy") +kvikio, has_kvikio = optional_import("kvikio") + __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] @@ -137,14 +143,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict): ) -def _stack_images(image_list: list, meta_dict: dict): +def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False): if len(image_list) <= 1: return image_list[0] if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) + if to_cupy and has_cp: + return cp.concatenate(image_list, axis=channel_dim) return np.concatenate(image_list, axis=channel_dim) # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 + if to_cupy and has_cp: + return cp.stack(image_list, axis=0) return np.stack(image_list, axis=0) @@ -864,12 +874,18 @@ class NibabelReader(ImageReader): Load NIfTI format images based on Nibabel library. Args: - as_closest_canonical: if True, load the image as closest to canonical axis format. - squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) channel_dim: the channel dimension of the input image, default is None. this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. if None, `original_channel_dim` will be either `no_channel` or `-1`. most Nifti files are usually "channel last", no need to specify this argument for them. + as_closest_canonical: if True, load the image as closest to canonical axis format. + squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) + to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading. + Default is False. CuPy and Kvikio are required for this option. + Note: For compressed NIfTI files, some operations may still be performed on CPU memory, + and the acceleration may not be significant. In some cases, it may be slower than loading on CPU. + In practical use, it's recommended to add a warm up call before the actual loading. + A related tutorial will be prepared in the future, and the document will be updated accordingly. kwargs: additional args for `nibabel.load` API. more details about available args: https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py @@ -880,14 +896,42 @@ def __init__( channel_dim: str | int | None = None, as_closest_canonical: bool = False, squeeze_non_spatial_dims: bool = False, + to_gpu: bool = False, **kwargs, ): super().__init__() self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.as_closest_canonical = as_closest_canonical self.squeeze_non_spatial_dims = squeeze_non_spatial_dims + if to_gpu and (not has_cp or not has_kvikio): + warnings.warn( + "NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading." + ) + to_gpu = False + + if to_gpu: + self.warmup_kvikio() + + self.to_gpu = to_gpu self.kwargs = kwargs + def warmup_kvikio(self): + """ + Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc. + This can accelerate the data loading process when `to_gpu` is set to True. + """ + if has_cp and has_kvikio: + a = cp.arange(100) + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file_name = tmp_file.name + f = kvikio.CuFile(tmp_file_name, "w") + f.write(a) + f.close() + + b = cp.empty_like(a) + f = kvikio.CuFile(tmp_file_name, "r") + f.read(b) + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by Nibabel reader. @@ -916,6 +960,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) + self.filenames = filenames kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -936,10 +981,13 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. """ + # TODO: the actual type is list[np.ndarray | cp.ndarray] + # should figure out how to define correct types without having cupy not found error + # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918 img_array: list[np.ndarray] = [] compatible_meta: dict = {} - for i in ensure_tuple(img): + for i, filename in zip(ensure_tuple(img), self.filenames): header = self._get_meta_dict(i) header[MetaKeys.AFFINE] = self._get_affine(i) header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) @@ -949,7 +997,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: header[MetaKeys.AFFINE] = self._get_affine(i) header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) header[MetaKeys.SPACE] = SpaceKeys.RAS - data = self._get_array_data(i) + data = self._get_array_data(i, filename) if self.squeeze_non_spatial_dims: for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1): if data.shape[d - 1] == 1: @@ -963,7 +1011,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim _copy_compatible_dict(header, compatible_meta) - return _stack_images(img_array, compatible_meta), compatible_meta + return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta def _get_meta_dict(self, img) -> dict: """ @@ -1015,14 +1063,34 @@ def _get_spatial_shape(self, img): spatial_rank = max(min(ndim, 3), 1) return np.asarray(size[:spatial_rank]) - def _get_array_data(self, img): + def _get_array_data(self, img, filename): """ Get the raw array data of the image, converted to Numpy array. Args: img: a Nibabel image object loaded from an image file. - - """ + filename: file name of the image. + + """ + if self.to_gpu: + file_size = os.path.getsize(filename) + image = cp.empty(file_size, dtype=cp.uint8) + with kvikio.CuFile(filename, "r") as f: + f.read(image) + if filename.endswith(".nii.gz"): + # for compressed data, have to tansfer to CPU to decompress + # and then transfer back to GPU. It is not efficient compared to .nii file + # and may be slower than CPU loading in some cases. + warnings.warn("Loading compressed NIfTI file into GPU may not be efficient.") + compressed_data = cp.asnumpy(image) + with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: + decompressed_data = gz_file.read() + + image = cp.frombuffer(decompressed_data, dtype=cp.uint8) + data_shape = img.shape + data_offset = img.dataobj.offset + data_dtype = img.dataobj.dtype + return image[data_offset:].view(data_dtype).reshape(data_shape, order="F") return np.asanyarray(img.dataobj, order="C") diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index ac171e8508..c4c491e1b9 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -553,7 +553,6 @@ def ensure_torch_and_prune_meta( However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned. """ img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray - # if not tracking metadata, return `torch.Tensor` if not isinstance(img, MetaTensor): return img diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 4e71870fc9..1023cd7a7d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -286,7 +286,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" f" The current registered: {self.readers}.\n{msg}" ) - img_array: NdarrayOrTensor img_array, meta_data = reader.get_data(img) img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0] diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py index cb45cb5146..8331f742ec 100644 --- a/tests/test_init_reader.py +++ b/tests/test_init_reader.py @@ -30,6 +30,17 @@ def test_load_image(self): inst = LoadImaged("image", reader=r) self.assertIsInstance(inst, LoadImaged) + @SkipIfNoModule("nibabel") + @SkipIfNoModule("cupy") + @SkipIfNoModule("kvikio") + def test_load_image_to_gpu(self): + for to_gpu in [True, False]: + instance1 = LoadImage(reader="NibabelReader", to_gpu=to_gpu) + self.assertIsInstance(instance1, LoadImage) + + instance2 = LoadImaged("image", reader="NibabelReader", to_gpu=to_gpu) + self.assertIsInstance(instance2, LoadImaged) + @SkipIfNoModule("itk") @SkipIfNoModule("nibabel") @SkipIfNoModule("PIL") @@ -58,6 +69,14 @@ def test_readers(self): inst = NrrdReader() self.assertIsInstance(inst, NrrdReader) + @SkipIfNoModule("nibabel") + @SkipIfNoModule("cupy") + @SkipIfNoModule("kvikio") + def test_readers_to_gpu(self): + for to_gpu in [True, False]: + inst = NibabelReader(to_gpu=to_gpu) + self.assertIsInstance(inst, NibabelReader) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 0207079d7d..a3e6d7bcfc 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -29,7 +29,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import LoadImage from monai.utils import optional_import -from tests.utils import assert_allclose, skip_if_downloading_fails, testing_data_config +from tests.utils import SkipIfNoModule, assert_allclose, skip_if_downloading_fails, testing_data_config itk, has_itk = optional_import("itk", allow_namespace_pkg=True) ITKReader, _ = optional_import("monai.data", name="ITKReader", as_type="decorator") @@ -74,6 +74,22 @@ def get_data(self, _obj): TEST_CASE_5 = [{"reader": NibabelReader(mmap=False)}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_GPU_1 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii.gz"], (128, 128, 128)] + +TEST_CASE_GPU_2 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii"], (128, 128, 128)] + +TEST_CASE_GPU_3 = [ + {"reader": "nibabelreader", "to_gpu": True}, + ["test_image.nii", "test_image2.nii", "test_image3.nii"], + (3, 128, 128, 128), +] + +TEST_CASE_GPU_4 = [ + {"reader": "nibabelreader", "to_gpu": True}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + (3, 128, 128, 128), +] + TEST_CASE_6 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)] TEST_CASE_7 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)] @@ -196,6 +212,29 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape): assert_allclose(result.affine, torch.eye(4)) self.assertTupleEqual(result.shape, expected_shape) + @SkipIfNoModule("nibabel") + @SkipIfNoModule("cupy") + @SkipIfNoModule("kvikio") + @parameterized.expand([TEST_CASE_GPU_1, TEST_CASE_GPU_2, TEST_CASE_GPU_3, TEST_CASE_GPU_4]) + def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape): + test_image = np.random.rand(128, 128, 128) + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) + result = LoadImage(image_only=True, **input_param)(filenames) + ext = "".join(Path(name).suffixes) + self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext)) + self.assertEqual(result.meta["space"], "RAS") + assert_allclose(result.affine, torch.eye(4)) + self.assertTupleEqual(result.shape, expected_shape) + + # verify gpu and cpu loaded data are the same + input_param_cpu = input_param.copy() + input_param_cpu["to_gpu"] = False + result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames) + self.assertTrue(torch.equal(result_cpu, result.cpu())) + @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9]) def test_itk_reader(self, input_param, filenames, expected_shape): test_image = np.random.rand(128, 128, 128)